Python进阶与实战:从并发编程到完整项目构建
1. 多线程与多进程:解锁Python的并发潜力
在Python的世界里,GIL(全局解释器锁) 是个绕不开的话题。正是因为这个设计,Python的多线程并不适合CPU密集型任务,但对于I/O密集型任务却非常有效。而多进程则能真正利用多核CPU的优势。
让我通过一个实际案例来展示它们的区别:
python
import threading
import multiprocessing
import time
import requests
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from queue import Queue
print("=== Python并发编程实战 ===")


# 1. Threading basics: a simulated web-crawling workload.
class ThreadingDemo:
    """Threading demo: "fetch" pages concurrently with simulated latency.

    Each fetch only sleeps, so the workload is I/O-bound and threads can
    overlap the waits despite the GIL.
    """

    def __init__(self, urls, delay_scale=1.0):
        """
        Args:
            urls: sequence of URLs to "fetch".
            delay_scale: multiplier applied to the simulated 1-3 s load
                time of each page; 0 disables sleeping (useful in tests).
        """
        self.urls = urls
        self.delay_scale = delay_scale
        self.results = []
        # Guards self.results, which every worker thread appends to.
        self.lock = threading.Lock()

    def _success_count(self):
        """Return how many recorded results have status 'success'."""
        with self.lock:
            return sum(1 for r in self.results if r['status'] == 'success')

    def fetch_url(self, url):
        """Simulate fetching *url* and append a result dict to self.results.

        Never raises: any exception is recorded as a result with
        status 'error'.
        """
        thread_name = threading.current_thread().name
        try:
            # Deterministic 1-3 s delay derived from the URL text. hash(str)
            # is salted per interpreter run (PYTHONHASHSEED), which made the
            # timings differ from one run to the next.
            delay = sum(map(ord, str(url))) % 3 + 1
            time.sleep(delay * self.delay_scale)
            # Simulated page size.
            content_length = len(url) * 100
            with self.lock:
                self.results.append({
                    'url': url,
                    'status': 'success',
                    'length': content_length,
                    'thread': thread_name,
                })
            print(f"[Thread] {thread_name} 完成 {url}")
        except Exception as e:
            with self.lock:
                self.results.append({
                    'url': url,
                    'status': 'error',
                    'error': str(e),
                    'thread': thread_name,
                })

    def run_threading(self):
        """Start one thread per URL, wait for all of them, return elapsed seconds.

        Results of any previous run are discarded first, so the printed
        success count covers this run only.
        """
        print("\n1. 传统线程方式执行")
        self.results = []
        start_time = time.perf_counter()
        threads = [
            threading.Thread(target=self.fetch_url, args=(url,))
            for url in self.urls
        ]
        for thread in threads:
            thread.start()
        # Wait for every thread to finish.
        for thread in threads:
            thread.join()
        elapsed = time.perf_counter() - start_time
        print(f"传统线程方式耗时: {elapsed:.2f}秒")
        print(f"成功获取: {self._success_count()}个页面")
        return elapsed

    def run_threadpool(self, max_workers=5):
        """Fetch all URLs on a ThreadPoolExecutor and return elapsed seconds.

        Args:
            max_workers: pool size. With fewer workers than URLs the sleeps
                run in several waves, so this can be slower than
                run_threading().
        """
        print("\n2. 线程池方式执行")
        self.results = []
        start_time = time.perf_counter()
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(self.fetch_url, url) for url in self.urls]
            # result() re-raises anything fetch_url failed to record.
            for future in futures:
                future.result()
        elapsed = time.perf_counter() - start_time
        print(f"线程池方式耗时: {elapsed:.2f}秒")
        print(f"成功获取: {self._success_count()}个页面")
        return elapsed
# 2. Multiprocessing in practice: CPU-bound work.
class ProcessingDemo:
    """Multiprocessing demo: CPU-bound work (naive recursive Fibonacci)."""

    @staticmethod
    def cpu_intensive_task(n):
        """Compute fibonacci(n) with naive recursion and time it.

        Kept as a staticmethod so ProcessPoolExecutor can pickle it by its
        qualified name.

        Returns:
            dict with 'input', 'result', the executing 'process' name and
            the elapsed 'time' in seconds.
        """
        def fibonacci(x):
            if x <= 1:
                return x
            return fibonacci(x - 1) + fibonacci(x - 2)

        start = time.perf_counter()
        result = fibonacci(n)
        elapsed = time.perf_counter() - start
        return {
            'input': n,
            'result': result,
            'process': multiprocessing.current_process().name,
            'time': elapsed,
        }

    def run_processing(self, numbers, max_workers=4):
        """Run cpu_intensive_task over *numbers* serially, then on a process pool.

        Args:
            numbers: Fibonacci inputs.
            max_workers: process-pool size.

        Returns:
            (single_elapsed, multi_elapsed, multi_results); multi_results is
            in the same order as *numbers*.
        """
        print("\n3. 多进程方式执行 (CPU密集型)")

        # Single-process baseline.
        print("单进程基准测试...")
        start_time = time.perf_counter()
        for n in numbers:
            self.cpu_intensive_task(n)
        single_elapsed = time.perf_counter() - start_time
        print(f"单进程耗时: {single_elapsed:.2f}秒")

        # Process pool; map() preserves input order.
        print("\n多进程执行...")
        start_time = time.perf_counter()
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            multi_results = list(executor.map(self.cpu_intensive_task, numbers))
        multi_elapsed = time.perf_counter() - start_time
        print(f"多进程耗时: {multi_elapsed:.2f}秒")
        # Guard: a zero duration (e.g. empty input, coarse clock) would
        # otherwise raise ZeroDivisionError.
        if multi_elapsed > 0:
            print(f"加速比: {single_elapsed/multi_elapsed:.2f}倍")
        return single_elapsed, multi_elapsed, multi_results
# 3. Producer-consumer pattern.
class ProducerConsumerDemo:
    """Producer-consumer with threads and a bounded, thread-safe queue.Queue.

    One producer puts every item followed by a ``None`` sentinel. Each
    consumer that receives the sentinel puts it back, so the other
    consumers stop as well.
    """

    def __init__(self, max_size=10, produce_delay=0.1, consume_delay=0.2):
        """
        Args:
            max_size: queue capacity; the producer blocks while it is full.
            produce_delay: simulated seconds spent producing each item.
            consume_delay: simulated seconds spent consuming each item.
        """
        self.queue = Queue(maxsize=max_size)
        self.results = []
        self.producer_done = False
        self.produce_delay = produce_delay
        self.consume_delay = consume_delay

    def producer(self, items):
        """Put each of *items* on the queue, then the end-of-stream sentinel."""
        for item in items:
            # Simulated production time.
            time.sleep(self.produce_delay)
            self.queue.put(item)
            print(f"[Producer] 生产: {item}")
        # End-of-stream signal.
        self.queue.put(None)
        self.producer_done = True
        print("[Producer] 生产完成")

    def consumer(self, consumer_id):
        """Take items until the sentinel arrives, recording one result per item."""
        while True:
            item = self.queue.get()
            try:
                if item is None:
                    # Re-post the sentinel so the other consumers can stop too.
                    self.queue.put(None)
                    print(f"[Consumer-{consumer_id}] 消费完成")
                    break
                # Simulated consumption time.
                time.sleep(self.consume_delay)
                result = f"处理后的-{item}"
                # list.append is thread-safe in CPython, so no extra lock.
                self.results.append((consumer_id, result))
                print(f"[Consumer-{consumer_id}] 消费: {item} -> {result}")
            finally:
                # Balance every get(), including the sentinel's.
                self.queue.task_done()

    def run(self, items, num_consumers=2):
        """Run one producer and *num_consumers* consumers over *items* to completion.

        Raises:
            ValueError: if num_consumers < 1 (the producer would block forever).
        """
        if num_consumers < 1:
            raise ValueError("num_consumers 必须大于等于 1")
        print("\n4. 生产者-消费者模式")

        # Start from clean state. A previous run leaves its re-posted
        # sentinel in the queue, which would make this run's consumers exit
        # at once and let the producer block forever once the queue fills.
        self.queue = Queue(maxsize=self.queue.maxsize)
        self.results = []
        self.producer_done = False

        producer_thread = threading.Thread(
            target=self.producer, args=(items,), name="Producer"
        )
        consumer_threads = [
            threading.Thread(
                target=self.consumer, args=(i + 1,), name=f"Consumer-{i + 1}"
            )
            for i in range(num_consumers)
        ]

        start_time = time.perf_counter()
        producer_thread.start()
        for thread in consumer_threads:
            thread.start()

        producer_thread.join()
        for thread in consumer_threads:
            thread.join()

        elapsed = time.perf_counter() - start_time
        print(f"生产者-消费者模式耗时: {elapsed:.2f}秒")
        print(f"处理结果数量: {len(self.results)}")
# 4. Thread / process communication.
class CommunicationDemo:
    """Communication between threads (queue.Queue) and between processes (Manager)."""

    @staticmethod
    def thread_communication():
        """Three worker threads double ten integers through a pair of queues.

        Returns:
            list of the ten result strings, in completion order.
        """
        print("\n5. 线程间通信")

        def worker(input_queue, output_queue, worker_id):
            while True:
                task = input_queue.get()
                if task is None:
                    # Re-post the sentinel so the other workers stop too.
                    input_queue.put(None)
                    break
                result = f"Worker-{worker_id}处理:{task*2}"
                output_queue.put(result)
                print(f"[Worker-{worker_id}] {task} -> {result}")
                input_queue.task_done()

        num_tasks = 10
        task_queue = Queue()
        result_queue = Queue()

        for i in range(num_tasks):
            task_queue.put(i)
        task_queue.put(None)  # end-of-work sentinel

        workers = [
            threading.Thread(target=worker, args=(task_queue, result_queue, i + 1))
            for i in range(3)
        ]
        for worker_thread in workers:
            worker_thread.start()

        # Exactly one result arrives per task.
        results = [result_queue.get() for _ in range(num_tasks)]

        for worker_thread in workers:
            worker_thread.join()

        print(f"处理结果: {results[:3]}...")  # show only the first three
        return results

    @staticmethod
    def _process_worker(shared_list, process_id):
        """Append three tagged entries to *shared_list*; return (process_id, pid).

        Defined at class level instead of inside process_communication():
        ProcessPoolExecutor pickles the callable, and local functions cannot
        be pickled.
        """
        import os

        pid = os.getpid()
        for i in range(3):
            shared_list.append(f"进程{process_id}(PID:{pid})-数据{i}")
        return process_id, pid

    @staticmethod
    def process_communication():
        """Three pool processes append to one Manager-backed shared list.

        Returns:
            (shared_data, process_results): a plain-list copy of the shared
            list and the (process_id, pid) tuples returned by the workers.
        """
        from multiprocessing import Manager

        print("\n6. 进程间通信")
        with Manager() as manager:
            shared_list = manager.list()
            with ProcessPoolExecutor(max_workers=3) as executor:
                futures = [
                    executor.submit(CommunicationDemo._process_worker, shared_list, i + 1)
                    for i in range(3)
                ]
                process_results = [future.result() for future in futures]
            # Copy out while the manager process is still alive.
            shared_data = list(shared_list)
        print(f"共享数据: {shared_data}")
        print(f"进程结果: {process_results}")
        return shared_data, process_results
# Run the examples.
if __name__ == "__main__":
    # Test data: ten fake page URLs.
    urls = [f"https://example.com/page{i}" for i in range(1, 11)]

    # Threading demo.
    print("=" * 50)
    print("多线程演示")
    print("=" * 50)
    thread_demo = ThreadingDemo(urls)
    thread_time = thread_demo.run_threading()
    threadpool_time = thread_demo.run_threadpool()

    # Multiprocessing demo.
    print("\n" + "=" * 50)
    print("多进程演示")
    print("=" * 50)
    process_demo = ProcessingDemo()
    numbers = [30, 31, 32, 33]  # Fibonacci inputs
    single_time, multi_time, _ = process_demo.run_processing(numbers)

    # Producer-consumer.
    print("\n" + "=" * 50)
    print("生产者-消费者模式")
    print("=" * 50)
    pc_demo = ProducerConsumerDemo(max_size=5)
    items = [f"任务-{i}" for i in range(1, 11)]
    pc_demo.run(items, num_consumers=3)

    # Communication demos.
    print("\n" + "=" * 50)
    print("进程/线程通信")
    print("=" * 50)
    comm_demo = CommunicationDemo()
    comm_demo.thread_communication()
    comm_demo.process_communication()

    # Performance summary.
    print("\n" + "=" * 50)
    print("性能对比总结")
    print("=" * 50)
    # With 10 URLs but only 5 pool workers, the pool sleeps in two waves and
    # is usually SLOWER than one thread per URL. Report whichever actually
    # won instead of always claiming the pool was faster.
    ratio = thread_time / threadpool_time
    if ratio >= 1:
        print(f"多线程 vs 线程池: 线程池快 {ratio:.2f}倍")
    else:
        print(f"多线程 vs 线程池: 传统线程快 {1 / ratio:.2f}倍")
    if multi_time > 0:
        print(f"多进程加速比: {single_time/multi_time:.2f}倍")

    print("\n使用建议:")
    print("1. I/O密集型任务 -> 使用多线程或异步")
    print("2. CPU密集型任务 -> 使用多进程")
    print("3. 需要共享状态 -> 多线程 + 锁机制")
    print("4. 需要真正并行 -> 多进程")
关键要点:
- 多线程适合I/O密集型任务,如网络请求、文件读写
- 多进程适合CPU密集型任务,如数学计算、图像处理
- 使用 concurrent.futures 模块可以简化并发编程
- 注意线程安全和进程间通信的问题
2. 异步编程:asyncio让I/O操作飞起来
异步编程是处理高并发I/O操作的利器。Python的asyncio库提供了完善的异步编程支持。
python
import asyncio
import aiohttp
import time
import json
from datetime import datetime
import signal
import sys
print("=== 异步编程深度探索 ===")


# 1. Async basics: coroutines and tasks.
class AsyncBasics:
    """Introductory asyncio examples: awaiting, gather(), and cancellation."""

    @staticmethod
    async def basic_coroutines():
        """Await a single coroutine, then run three at once with gather()."""
        print("1. 基础协程")

        async def say_hello(name, delay):
            """Wait *delay* seconds, then return a greeting for *name*."""
            await asyncio.sleep(delay)
            return f"Hello, {name}!"

        greeting = await say_hello("Python", 1)
        print(f"结果: {greeting}")

        print("\n运行多个任务:")
        schedule = (("Alice", 1), ("Bob", 2), ("Charlie", 3))
        pending = [say_hello(who, secs) for who, secs in schedule]
        greetings = await asyncio.gather(*pending)
        print(f"所有结果: {greetings}")

    @staticmethod
    async def task_management():
        """Start two tasks, cancel the slower one after 2 s, await the other."""
        print("\n2. 任务管理")

        async def long_running_task(task_id, duration):
            print(f"任务 {task_id} 开始")
            try:
                await asyncio.sleep(duration)
            except asyncio.CancelledError:
                print(f"任务 {task_id} 被取消")
                raise
            print(f"任务 {task_id} 完成")
            return f"任务 {task_id} 结果"

        slow = asyncio.create_task(long_running_task(1, 5))
        fast = asyncio.create_task(long_running_task(2, 3))

        # Let both run for a while, then cancel the slow one.
        await asyncio.sleep(2)
        slow.cancel()
        try:
            await slow
        except asyncio.CancelledError:
            print("任务1已取消")

        fast_result = await fast
        print(f"任务2结果: {fast_result}")
# 2. Async HTTP requests.
class AsyncHTTPClient:
    """Async HTTP client around one shared aiohttp.ClientSession.

    Intended use: ``async with AsyncHTTPClient() as client: ...`` so the
    session is opened on entry and closed on exit.
    """

    def __init__(self):
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Tolerate a client whose session was never opened.
        if self.session is not None:
            await self.session.close()
            self.session = None

    async def fetch_url(self, url, session=None):
        """GET *url*; never raises for request errors.

        Args:
            url: URL to fetch.
            session: aiohttp session to use; defaults to ``self.session``.

        Returns:
            dict with 'url' and 'status' ('success' | 'error' | 'exception'),
            plus 'length'/'status_code' or 'error' depending on the outcome.
        """
        if session is None:
            session = self.session
        try:
            # ClientTimeout is aiohttp's documented timeout type; a bare
            # number is only accepted for backward compatibility.
            timeout = aiohttp.ClientTimeout(total=10)
            async with session.get(url, timeout=timeout) as response:
                if response.status == 200:
                    text = await response.text()
                    return {
                        'url': url,
                        'status': 'success',
                        'length': len(text),
                        'status_code': response.status,
                    }
                return {
                    'url': url,
                    'status': 'error',
                    'status_code': response.status,
                }
        except Exception as e:
            # Best-effort fetch: report the failure instead of propagating it.
            return {
                'url': url,
                'status': 'exception',
                'error': str(e),
            }

    async def concurrent_fetch(self, urls, max_concurrent=10):
        """Fetch *urls* in batches of *max_concurrent*, printing per-batch progress.

        Returns:
            list of fetch_url() result dicts, in the same order as *urls*.

        Raises:
            ValueError: if max_concurrent < 1.
        """
        if max_concurrent < 1:
            raise ValueError("max_concurrent 必须大于等于 1")
        print("\n3. 异步网络请求")
        urls = list(urls)

        # Cap the number of in-flight requests.
        semaphore = asyncio.Semaphore(max_concurrent)

        async def fetch_with_semaphore(url):
            async with semaphore:
                return await self.fetch_url(url, self.session)

        print(f"开始获取 {len(urls)} 个URL...")
        start_time = time.perf_counter()

        # Create fresh coroutines per batch and keep their results. A
        # coroutine object can be awaited only once, so gathering the same
        # coroutines a second time raises RuntimeError.
        all_results = []
        batch_starts = range(0, len(urls), max_concurrent)
        for batch_no, offset in enumerate(batch_starts, start=1):
            batch_urls = urls[offset:offset + max_concurrent]
            batch_results = await asyncio.gather(
                *(fetch_with_semaphore(u) for u in batch_urls)
            )
            successful = sum(1 for r in batch_results if r['status'] == 'success')
            print(f"批次 {batch_no}: 成功 {successful}/{len(batch_urls)}")
            all_results.extend(batch_results)

        elapsed = time.perf_counter() - start_time
        success_count = sum(1 for r in all_results if r['status'] == 'success')
        print(f"\n完成! 耗时: {elapsed:.2f}秒")
        print(f"成功: {success_count}/{len(urls)}")
        if urls:  # avoid ZeroDivisionError on empty input
            print(f"平均每个请求: {elapsed/len(urls):.3f}秒")
        return all_results

    async def api_rate_limited(self, api_endpoints, requests_per_second=5,
                               api_latency=0.5):
        """Call simulated endpoints, starting at most *requests_per_second* calls per second.

        Call start times are spaced ``1 / requests_per_second`` seconds
        apart. A semaphore alone only caps concurrency, so throughput can
        exceed the requested rate.

        Args:
            api_endpoints: endpoint paths to call.
            requests_per_second: maximum call starts per second (> 0).
            api_latency: simulated seconds each call takes.

        Returns:
            list of response dicts, in the order of *api_endpoints*.
        """
        if requests_per_second <= 0:
            raise ValueError("requests_per_second 必须大于 0")
        print("\n4. 速率限制请求")

        async def call_api(endpoint):
            # Simulated API processing time.
            await asyncio.sleep(api_latency)
            return {
                'endpoint': endpoint,
                'data': f"来自 {endpoint} 的响应",
                'timestamp': datetime.now().isoformat(),
            }

        loop = asyncio.get_running_loop()
        interval = 1.0 / requests_per_second
        next_start = loop.time()

        async def rate_limited_call(endpoint):
            nonlocal next_start
            # Reserve the next free start slot. No lock is needed: there is no
            # await between reading and updating next_start, and the event
            # loop only switches tasks at an await.
            now = loop.time()
            start = max(now, next_start)
            next_start = start + interval
            if start > now:
                await asyncio.sleep(start - now)
            return await call_api(endpoint)

        results = await asyncio.gather(
            *(rate_limited_call(endpoint) for endpoint in api_endpoints)
        )
        print(f"完成 {len(results)} 个API调用")
        return results
# 3. Async producer-consumer.
class AsyncProducerConsumer:
    """Producer-consumer over an asyncio.Queue with several producers and consumers."""

    def __init__(self, queue_size=100):
        self.queue = asyncio.Queue(maxsize=queue_size)
        self.producers_done = 0
        self.total_producers = 0

    async def producer(self, producer_id, items):
        """Put one tagged string per item on the queue (0.1 s simulated work each)."""
        self.total_producers += 1
        for item in items:
            # Simulated production time.
            await asyncio.sleep(0.1)
            produced_item = f"Producer-{producer_id}:{item}"
            await self.queue.put(produced_item)
            print(f"[Producer-{producer_id}] 生产: {produced_item}")
        self.producers_done += 1
        print(f"[Producer-{producer_id}] 生产完成")

    async def consumer(self, consumer_id):
        """Consume items until every producer has finished and the queue is empty.

        Polls with a 1 s timeout so it can notice that production has
        ended. run() also cancels consumers once the queue has drained.
        """
        while True:
            try:
                item = await asyncio.wait_for(self.queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                # Check whether every producer has finished.
                if self.producers_done >= self.total_producers and self.queue.empty():
                    print(f"[Consumer-{consumer_id}] 所有任务完成,退出")
                    break
                continue
            try:
                # Simulated consumption time.
                await asyncio.sleep(0.2)
                print(f"[Consumer-{consumer_id}] 消费: {item}")
            finally:
                # Always mark the item done so queue.join() cannot hang if
                # processing raises.
                self.queue.task_done()

    async def run(self, producers_data, num_consumers=3):
        """Run one producer per entry of *producers_data* plus *num_consumers* consumers.

        Raises:
            ValueError: if num_consumers < 1 (queue.join() would never return).
        """
        if num_consumers < 1:
            raise ValueError("num_consumers 必须大于等于 1")
        print("\n5. 异步生产者-消费者模式")

        consumer_tasks = [
            asyncio.create_task(self.consumer(i + 1))
            for i in range(num_consumers)
        ]
        try:
            producer_tasks = [
                asyncio.create_task(self.producer(i + 1, items))
                for i, items in enumerate(producers_data)
            ]
            await asyncio.gather(*producer_tasks)
            print("所有生产者已完成")

            # Wait until every queued item has been processed.
            await self.queue.join()
            print("队列已清空")
        finally:
            # Stop the consumers even if a producer failed, so no task leaks.
            for task in consumer_tasks:
                task.cancel()
            await asyncio.gather(*consumer_tasks, return_exceptions=True)
        print("所有消费者已完成")
# 4. Async web server.
class AsyncWebServer:
    """Tiny HTTP/1.1 server on asyncio streams: one request per connection."""

    # Reason phrases for the status codes this server emits.
    _REASONS = {200: "OK", 400: "Bad Request", 404: "Not Found"}

    @staticmethod
    def _parse_request_line(request_text):
        """Split the request line into (method, path, version).

        Returns ('GET', '/', 'HTTP/1.1') for an empty request, and None
        when the first line does not have exactly three parts.
        """
        request_line = request_text.split('\r\n', 1)[0]
        if not request_line.strip():
            return 'GET', '/', 'HTTP/1.1'
        parts = request_line.split()
        if len(parts) != 3:
            return None
        method, path, version = parts
        return method, path, version

    @staticmethod
    def _route(path):
        """Map *path* to (status_code, body_text, content_type)."""
        if path == '/':
            return 200, "Hello, Async World!", 'text/plain'
        if path == '/time':
            return 200, datetime.now().isoformat(), 'text/plain'
        if path.startswith('/echo/'):
            return 200, path[6:], 'text/plain'  # strip the '/echo/' prefix
        if path == '/json':
            body = json.dumps({
                'message': 'Hello from async server!',
                'timestamp': datetime.now().isoformat(),
                'status': 'success',
            })
            return 200, body, 'application/json'
        return 404, "404 Not Found", 'text/plain'

    @staticmethod
    def _build_response(status, body, content_type):
        """Serialize a complete HTTP/1.1 response to bytes.

        Header lines end in CRLF and an empty line separates headers from
        the body, as HTTP requires. Content-Length counts encoded bytes
        rather than characters, so non-ASCII bodies are framed correctly.
        """
        payload = body.encode('utf-8')
        reason = AsyncWebServer._REASONS.get(status, "")
        head = (
            f"HTTP/1.1 {status} {reason}\r\n"
            f"Content-Type: {content_type}; charset=utf-8\r\n"
            f"Content-Length: {len(payload)}\r\n"
            "Connection: close\r\n"
            "\r\n"
        )
        return head.encode('ascii') + payload

    @staticmethod
    async def handle_request(reader, writer):
        """Read one request, log it, send the routed response, close the connection."""
        try:
            request = await reader.read(4096)
            # Undecodable bytes must not crash the handler.
            request_text = request.decode('utf-8', errors='replace')

            addr = writer.get_extra_info('peername')
            client = f"{addr[0]}:{addr[1]}" if addr else "unknown"

            parsed = AsyncWebServer._parse_request_line(request_text)
            if parsed is None:
                print(f"[{datetime.now()}] {client} - 400 Bad Request")
                status, body, content_type = 400, "400 Bad Request", 'text/plain'
            else:
                method, path, _version = parsed
                print(f"[{datetime.now()}] {client} - {method} {path}")
                status, body, content_type = AsyncWebServer._route(path)

            writer.write(AsyncWebServer._build_response(status, body, content_type))
            await writer.drain()
        finally:
            # Close even when reading or writing failed.
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                # The peer may already be gone; nothing left to clean up.
                pass

    @staticmethod
    async def start_server(host='127.0.0.1', port=8888):
        """Start the server on host:port and serve until cancelled."""
        print(f"\n6. 启动异步Web服务器: http://{host}:{port}")
        server = await asyncio.start_server(
            AsyncWebServer.handle_request,
            host,
            port,
        )
        addr = server.sockets[0].getsockname()
        print(f'服务器运行在 {addr}')
        async with server:
            await server.serve_forever()
# 5. Async context manager.
class AsyncDatabaseConnection:
    """Simulated database connection used with ``async with``.

    Entering "connects" (1 s simulated), leaving "disconnects" (0.5 s
    simulated); query/execute require an open connection.
    """

    def __init__(self, connection_string):
        self.connection_string = connection_string
        self.connected = False

    def _ensure_connected(self):
        """Raise RuntimeError unless the connection is open."""
        if not self.connected:
            raise RuntimeError("未连接到数据库")

    async def __aenter__(self):
        print(f"连接到数据库: {self.connection_string}")
        await asyncio.sleep(1)  # simulated connect latency
        self.connected = True
        print("数据库连接成功")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Nothing to close if we never connected; exceptions are not suppressed.
        if not self.connected:
            return
        print("关闭数据库连接")
        await asyncio.sleep(0.5)  # simulated close latency
        self.connected = False

    async def query(self, sql):
        """Pretend to run *sql* and return three fixed user rows."""
        self._ensure_connected()
        print(f"执行查询: {sql}")
        await asyncio.sleep(0.3)  # simulated query latency
        canned = (("Alice", 25), ("Bob", 30), ("Charlie", 35))
        return [
            {"id": row_id, "name": name, "age": age}
            for row_id, (name, age) in enumerate(canned, start=1)
        ]

    async def execute(self, sql):
        """Pretend to run a write statement; always reports one affected row."""
        self._ensure_connected()
        print(f"执行更新: {sql}")
        await asyncio.sleep(0.2)  # simulated execute latency
        return {"rows_affected": 1}
# Async entry point.
async def main():
    """Run the async demos in order: basics, HTTP client, producer-consumer, database."""
    rule = "=" * 50
    print(rule)
    print("异步编程实战")
    print(rule)

    # 1. Coroutine and task basics.
    basics = AsyncBasics()
    await basics.basic_coroutines()
    await basics.task_management()

    # 2. HTTP client: five distinct URLs, each listed twice (10 in total).
    distinct_urls = [
        "https://httpbin.org/get",
        "https://httpbin.org/status/200",
        "https://httpbin.org/status/404",
        "https://httpbin.org/delay/2",
        "https://httpbin.org/headers",
    ]
    async with AsyncHTTPClient() as client:
        await client.concurrent_fetch(distinct_urls * 2, max_concurrent=3)
        await client.api_rate_limited(
            [f"/api/v1/users/{n}" for n in range(1, 11)],
            requests_per_second=3,
        )

    # 3. Three producers with three items each, two consumers.
    pc = AsyncProducerConsumer(queue_size=5)
    await pc.run(
        [[f"Item-{p}-{k}" for k in range(1, 4)] for p in range(1, 4)],
        num_consumers=2,
    )

    # 4. Simulated database session.
    print("\n7. 异步数据库操作")
    async with AsyncDatabaseConnection("postgresql://user:pass@localhost/db") as db:
        rows = await db.query("SELECT * FROM users LIMIT 3")
        print(f"查询结果: {rows}")
        outcome = await db.execute("UPDATE users SET age = 26 WHERE id = 1")
        print(f"更新结果: {outcome}")

    print("\n" + rule)
    print("异步编程示例完成!")
    print(rule)
# Run the async entry point when executed as a script.
if __name__ == "__main__":
    on_windows = sys.platform == 'win32'
    if on_windows:
        # Windows: switch to the selector-based event loop policy.
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n程序被用户中断")
异步编程最佳实践:
- 使用 async/await 语法,避免直接使用回调
- 合理控制并发数量,避免资源耗尽
- 使用 asyncio.gather() 并发执行多个任务(单线程内并发,而非多核并行)
- 注意异常处理,使用 try...except 包装 await 调用
- 对于CPU密集型任务,仍然应该使用多进程
3. 装饰器与元编程:Python的魔法时刻
装饰器和元编程是Python中最强大的特性之一,它们让代码更加灵活和可重用。
python
import time
import functools
from datetime import datetime
from typing import Any, Callable, Type, TypeVar
from contextlib import ContextDecorator
import inspect
print("=== Python装饰器与元编程深度探索 ===")


# 1. Basic decorators.
class DecoratorBasics:
    """Basic function decorators: timing, call tracing, and retry."""

    @staticmethod
    def timer(func: Callable) -> Callable:
        """Print how long each call to *func* takes."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution; time.time()
            # can jump when the system clock is adjusted.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start_time
            print(f"函数 {func.__name__} 执行时间: {elapsed:.4f}秒")
            return result
        return wrapper

    @staticmethod
    def debug(func: Callable) -> Callable:
        """Print the name, arguments, and return value of each call to *func*."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            print(f"[DEBUG] 调用函数: {func.__name__}")
            print(f"[DEBUG] 位置参数: {args}")
            print(f"[DEBUG] 关键字参数: {kwargs}")
            result = func(*args, **kwargs)
            print(f"[DEBUG] 返回值: {result}")
            return result
        return wrapper

    @staticmethod
    def retry(max_attempts: int = 3, delay: float = 1.0,
              exceptions: tuple = (Exception,)):
        """Retry the decorated function when it raises one of *exceptions*.

        Args:
            max_attempts: total number of calls to try (>= 1).
            delay: seconds to sleep between attempts.
            exceptions: exception types that trigger a retry; any other
                exception propagates immediately.

        Raises:
            ValueError: at decoration time if max_attempts < 1. Without this
                check the wrapper would return None without ever calling the
                function.
        """
        if max_attempts < 1:
            raise ValueError("max_attempts 必须大于等于 1")

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for attempt in range(1, max_attempts + 1):
                    try:
                        return func(*args, **kwargs)
                    except exceptions as e:
                        if attempt == max_attempts:
                            print(f"函数 {func.__name__} 失败,已达最大重试次数")
                            raise
                        print(f"函数 {func.__name__} 第{attempt}次失败: {e}")
                        print(f"{delay}秒后重试...")
                        time.sleep(delay)
                # Unreachable: the last attempt either returns or re-raises.
            return wrapper
        return decorator
# 2. Class decorators.
class ClassDecorators:
    """Decorators that take a class and return a modified or wrapped class."""

    @staticmethod
    def singleton(cls):
        """Replace *cls* with a factory function that always returns the same instance.

        Arguments of calls after the first are ignored. The result is a
        function, not a class, so it cannot be used with isinstance().
        """
        created = {}

        @functools.wraps(cls)
        def get_instance(*args, **kwargs):
            if cls in created:
                print(f"返回 {cls.__name__} 的现有实例")
            else:
                created[cls] = cls(*args, **kwargs)
                print(f"创建 {cls.__name__} 的新实例")
            return created[cls]

        return get_instance

    @staticmethod
    def add_methods(methods_dict):
        """Attach every name -> function pair in *methods_dict* to the decorated class."""
        def decorator(cls):
            for pair in methods_dict.items():
                setattr(cls, *pair)
            return cls
        return decorator

    @staticmethod
    def auto_register(registry=None):
        """Build a decorator that records classes by name in one shared registry dict.

        The dict is exposed both as ``decorator.registry`` and as the
        ``registry`` attribute of every decorated class.
        """
        table = {} if registry is None else registry

        def decorator(cls):
            table[cls.__name__] = cls
            cls.registry = table
            return cls

        decorator.registry = table
        return decorator
# 3. Decorators that take parameters.
class ParametricDecorators:
    """Decorator factories: rate limiting, input validation, and an LRU result cache."""

    @staticmethod
    def rate_limited(max_per_second):
        """Space calls to the decorated function at least ``1 / max_per_second`` s apart.

        Not thread-safe: intended for single-threaded callers.

        Raises:
            ValueError: if max_per_second <= 0.
        """
        if max_per_second <= 0:
            raise ValueError("max_per_second 必须大于 0")
        min_interval = 1.0 / max_per_second

        def decorator(func):
            last_called = None  # monotonic timestamp of the previous call

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                nonlocal last_called
                # monotonic() is immune to wall-clock adjustments.
                if last_called is not None:
                    left_to_wait = min_interval - (time.monotonic() - last_called)
                    if left_to_wait > 0:
                        time.sleep(left_to_wait)
                last_called = time.monotonic()
                return func(*args, **kwargs)
            return wrapper
        return decorator

    @staticmethod
    def validate_input(*validators):
        """Check the leading arguments of each call, one validator per position.

        Validator i checks the i-th argument whether it was passed
        positionally or by keyword; the keyword name comes from the wrapped
        function's signature. Non-callable entries are skipped, so None can
        serve as a placeholder.

        Raises:
            ValueError: from the wrapper when a validator returns a falsy value.
        """
        def decorator(func):
            try:
                param_names = list(inspect.signature(func).parameters)
            except (TypeError, ValueError):
                # Some callables (e.g. builtins) expose no signature; only
                # positional arguments can be checked then.
                param_names = []

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for i, validator in enumerate(validators):
                    if not callable(validator):
                        continue
                    if i < len(args):
                        arg = args[i]
                    elif i < len(param_names) and param_names[i] in kwargs:
                        arg = kwargs[param_names[i]]
                    else:
                        continue  # argument not supplied; let func apply its default
                    if not validator(arg):
                        raise ValueError(f"参数 {i} 验证失败: {arg}")
                return func(*args, **kwargs)
            return wrapper
        return decorator

    @staticmethod
    def cache_results(max_size=128):
        """Memoize results with least-recently-used eviction beyond *max_size* entries.

        Arguments must be hashable. The wrapper gains ``cache_clear()`` and
        ``cache_info()`` (keys: size, max_size, hits, misses).
        """
        from collections import OrderedDict

        def decorator(func):
            # Ordered from least to most recently used.
            cache = OrderedDict()
            stats = {'hits': 0, 'misses': 0}

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                cache_key = (args, tuple(sorted(kwargs.items())))
                if cache_key in cache:
                    # A hit refreshes recency; this is what makes it LRU
                    # rather than FIFO.
                    cache.move_to_end(cache_key)
                    stats['hits'] += 1
                    print(f"[CACHE HIT] {func.__name__}{args}")
                    return cache[cache_key]

                print(f"[CACHE MISS] {func.__name__}{args}")
                stats['misses'] += 1
                result = func(*args, **kwargs)
                cache[cache_key] = result
                # Evict the least recently used entry once over capacity.
                if len(cache) > max_size:
                    cache.popitem(last=False)
                return result

            def cache_clear():
                cache.clear()
                stats['hits'] = 0
                stats['misses'] = 0

            def cache_info():
                return {
                    'size': len(cache),
                    'max_size': max_size,
                    'hits': stats['hits'],
                    'misses': stats['misses'],
                }

            wrapper.cache_clear = cache_clear
            wrapper.cache_info = cache_info
            return wrapper
        return decorator
# 4. Context managers that double as decorators.
class ContextManagerDecorators:
    """Context-manager helpers usable with ``with`` or as decorators."""

    class timed_block(ContextDecorator):
        """Print how long the enclosed block (or decorated function) took.

        Never suppresses exceptions.
        """

        def __enter__(self):
            self.start_time = time.time()
            print("开始计时...")
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.end_time = time.time()
            print(f"代码块执行时间: {self.end_time - self.start_time:.4f}秒")
            return False

    @staticmethod
    def as_context_manager(func):
        """Turn *func* into a factory whose ``with`` block receives func's return value.

        *func* runs on ``__enter__`` (not when the factory is called); exit
        performs no cleanup and does not suppress exceptions.
        """
        class FunctionContextManager:
            def __init__(self, call_args, call_kwargs):
                self._call_args = call_args
                self._call_kwargs = call_kwargs

            def __enter__(self):
                self.result = func(*self._call_args, **self._call_kwargs)
                return self.result

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Cleanup code could go here.
                return None

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return FunctionContextManager(args, kwargs)

        return wrapper
# 5. Metaclass programming.
class MetaProgramming:
    """Metaclass examples: singleton, auto-registration, and validated fields."""

    class SingletonMeta(type):
        """Metaclass whose classes hand out one shared instance each."""
        # A single dict shared by every class using this metaclass, keyed by class.
        _instances = {}

        def __call__(cls, *args, **kwargs):
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]

    class AutoRegisterMeta(type):
        """Metaclass that records every class it creates in a ``registry`` dict.

        The first class in a hierarchy creates the dict; subclasses inherit
        that attribute and add themselves to the same dict.
        """

        def __init__(cls, name, bases, attrs):
            super().__init__(name, bases, attrs)
            if not hasattr(cls, 'registry'):
                cls.registry = {}
            cls.registry[name] = cls

    class ValidateFieldsMeta(type):
        """Metaclass that validates instance attribute assignments.

        A class-body function marked with ``@MetaProgramming.validator`` and
        named like an attribute becomes that attribute's validator: each
        assignment calls ``validator(value)`` and raises ValueError when the
        result is falsy. Subclasses inherit their bases' validators.
        """

        def __new__(mcs, name, bases, attrs):
            validators = {}
            # Inherited validators first (earlier bases win), so subclasses
            # keep validating the fields they inherit.
            for base in reversed(bases):
                validators.update(getattr(base, '_validators', {}))
            for attr_name, attr_value in attrs.items():
                # Store the function itself: the @validator marker is just the
                # flag True, and calling it raised "'bool' object is not callable".
                if callable(attr_value) and getattr(attr_value, '_validator', False):
                    validators[attr_name] = attr_value

            cls = super().__new__(mcs, name, bases, attrs)
            cls._validators = validators

            # Wrap __setattr__ so every assignment is validated first.
            original_setattr = cls.__setattr__

            def new_setattr(self, attr, value):
                check = self._validators.get(attr)
                if check is not None and not check(value):
                    raise ValueError(f"字段 {attr} 验证失败: {value}")
                original_setattr(self, attr, value)

            cls.__setattr__ = new_setattr
            return cls

    @staticmethod
    def validator(func):
        """Mark *func* as a field validator for ValidateFieldsMeta."""
        func._validator = True
        return func
# 6. Descriptors.
class Descriptors:
    """Descriptor examples: validated, cached, and observable attributes.

    Per-instance values live in the owning instance's ``__dict__`` under a
    private key, not in a descriptor-level dict keyed by ``id(obj)``. Such
    a dict never frees its entries, and ids are reused after garbage
    collection, so a new object could see a dead object's value. Owning
    classes therefore need a ``__dict__`` (no ``__slots__``-only classes).
    """

    class ValidatedAttribute:
        """Data descriptor that runs *validator* on every assignment."""

        def __init__(self, validator=None):
            self.validator = validator
            # Fallback key for descriptors attached after class creation
            # (when __set_name__ is never called).
            self._key = f"_ValidatedAttribute__{id(self)}"

        def __set_name__(self, owner, name):
            self._key = f"_ValidatedAttribute__{name}"

        def __get__(self, obj, objtype=None):
            if obj is None:
                return self
            return obj.__dict__.get(self._key)

        def __set__(self, obj, value):
            if self.validator and not self.validator(value):
                raise ValueError(f"值验证失败: {value}")
            obj.__dict__[self._key] = value

        def __delete__(self, obj):
            obj.__dict__.pop(self._key, None)

    class CachedProperty:
        """Read-only attribute computed by *func* on first access, then cached per instance."""

        def __init__(self, func):
            self.func = func
            self._key = f"_CachedProperty__{getattr(func, '__name__', id(self))}"

        def __set_name__(self, owner, name):
            self._key = f"_CachedProperty__{name}"

        def __get__(self, obj, objtype=None):
            if obj is None:
                return self
            store = obj.__dict__
            if self._key not in store:
                store[self._key] = self.func(obj)
            return store[self._key]

        def __set__(self, obj, value):
            raise AttributeError("缓存属性是只读的")

    class ObservableAttribute:
        """Data descriptor that notifies per-instance observers on every assignment.

        Observers are called as ``observer(obj, old_value, new_value)``.
        Register them through the descriptor itself, e.g.
        ``type(obj).attr.add_observer(obj, cb)``; ``obj.attr`` returns the
        stored value, not the descriptor.
        """

        def __init__(self, default=None):
            self.default = default
            self._key = f"_ObservableAttribute__{id(self)}"

        def __set_name__(self, owner, name):
            self._key = f"_ObservableAttribute__{name}"

        def __get__(self, obj, objtype=None):
            if obj is None:
                return self
            return obj.__dict__.get(self._key, self.default)

        def __set__(self, obj, value):
            # Before the first assignment the visible value is the default,
            # so report that as the old value.
            old_value = obj.__dict__.get(self._key, self.default)
            obj.__dict__[self._key] = value
            # Iterate over a copy so observers may (un)register during notification.
            for observer in list(obj.__dict__.get(self._key + "__observers", ())):
                observer(obj, old_value, value)

        def add_observer(self, obj, observer):
            """Register *observer* for assignments on *obj*."""
            obj.__dict__.setdefault(self._key + "__observers", []).append(observer)
# 7. Putting it together: a mini ORM built on a metaclass.
class MiniORM:
    """Mini ORM: Field declarations gathered by a metaclass into Model subclasses."""

    class Field:
        """Base column declaration: expected Python type, nullability, default value."""

        def __init__(self, field_type, nullable=True, default=None):
            self.field_type = field_type
            self.nullable = nullable
            self.default = default

        def validate(self, value):
            """Return True if *value* is acceptable for this field.

            None is accepted only for nullable fields. A value of another type
            is accepted when ``field_type(value)`` succeeds; the value itself
            is not converted.
            """
            if value is None:
                return self.nullable
            if not isinstance(value, self.field_type):
                # Try a type conversion only to test convertibility.
                try:
                    self.field_type(value)
                except (ValueError, TypeError):
                    return False
            return True

    class IntegerField(Field):
        """Integer column."""

        def __init__(self, nullable=True, default=None):
            super().__init__(int, nullable, default)

    class StringField(Field):
        """String column with a maximum length."""

        def __init__(self, max_length=255, nullable=True, default=None):
            super().__init__(str, nullable, default)
            self.max_length = max_length

        def validate(self, value):
            if not super().validate(value):
                return False
            # Measure str(value): a non-str value such as 123 can pass the
            # base check, and len() on it directly raises TypeError.
            return value is None or len(str(value)) <= self.max_length

    class ModelMeta(type):
        """Metaclass that gathers Field attributes into ``_fields`` and sets ``_table_name``.

        ``_table_name`` is the class's own ``__tablename__`` if present,
        otherwise the lower-cased class name.
        """

        # Field base class to look for; assigned right after this class body.
        # It cannot be written as MiniORM.Field: the name MiniORM is not bound
        # yet when Model (defined inside this class body) is created, so that
        # lookup raised NameError.
        _field_base = None

        def __new__(mcs, name, bases, attrs):
            fields = {}
            # Inherited fields first (earlier bases win), so a subclass can
            # override a field by redeclaring it.
            for base in reversed(bases):
                fields.update(getattr(base, '_fields', {}))
            for attr_name, attr_value in attrs.items():
                if isinstance(attr_value, mcs._field_base):
                    fields[attr_name] = attr_value
            attrs['_fields'] = fields
            attrs['_table_name'] = attrs.get('__tablename__', name.lower())
            return super().__new__(mcs, name, bases, attrs)

    ModelMeta._field_base = Field

    class Model(metaclass=ModelMeta):
        """Base class for models; subclasses declare Field class attributes."""

        def __init__(self, **kwargs):
            # Missing fields take the field's default; unknown keyword
            # arguments are ignored.
            for field_name, field in self._fields.items():
                setattr(self, field_name, kwargs.get(field_name, field.default))
            self.validate()

        def validate(self):
            """Raise ValueError listing every field whose current value is invalid."""
            errors = []
            for field_name, field in self._fields.items():
                value = getattr(self, field_name, None)
                if not field.validate(value):
                    errors.append(f"字段 {field_name} 验证失败: {value}")
            if errors:
                raise ValueError("; ".join(errors))

        def to_dict(self):
            """Return a dict of field name -> current value, in declaration order."""
            return {
                field_name: getattr(self, field_name, None)
                for field_name in self._fields
            }

        @classmethod
        def from_dict(cls, data):
            """Construct an instance from a mapping of field values."""
            return cls(**data)

        def __repr__(self):
            field_values = ", ".join(
                f"{field_name}={getattr(self, field_name, None)!r}"
                for field_name in self._fields
            )
            return f"{self.__class__.__name__}({field_values})"
# 8. Walkthrough of the decorator and metaprogramming examples.
def demonstration():
    """Run every decorator / metaclass / descriptor example in sequence, printing results."""

    # Basic decorators: timer wraps debug, which wraps the function.
    @DecoratorBasics.timer
    @DecoratorBasics.debug
    def calculate_sum(n):
        """Return the sum of 1..n."""
        return sum(range(1, n + 1))

    print("1. 基础装饰器示例:")
    result = calculate_sum(100)
    print(f"结果: {result}")

    # Retry decorator.
    @DecoratorBasics.retry(max_attempts=3, delay=1)
    def unreliable_function():
        """Fail at random (about 70% of calls) to exercise the retry decorator."""
        import random
        if random.random() < 0.7:
            raise ValueError("随机失败!")
        return "成功!"

    print("\n2. 重试装饰器示例:")
    try:
        result = unreliable_function()
        print(f"最终结果: {result}")
    except Exception as e:
        print(f"最终失败: {e}")

    # Singleton decorator.
    @ClassDecorators.singleton
    class DatabaseConnection:
        def __init__(self, connection_string):
            self.connection_string = connection_string
            print(f"初始化数据库连接: {connection_string}")

    print("\n3. 单例装饰器示例:")
    db1 = DatabaseConnection("mysql://localhost/test")
    db2 = DatabaseConnection("mysql://localhost/test")
    print(f"db1 is db2: {db1 is db2}")

    # Rate-limit decorator.
    @ParametricDecorators.rate_limited(max_per_second=2)
    def api_call(endpoint):
        """Simulated API call."""
        print(f"调用API: {endpoint}")
        time.sleep(0.1)  # simulated network latency
        return f"响应来自 {endpoint}"

    print("\n4. 速率限制装饰器示例:")
    for i in range(5):
        result = api_call(f"/api/v1/users/{i}")
        print(f"结果: {result}")

    # Caching decorator.
    @ParametricDecorators.cache_results(max_size=3)
    def expensive_computation(n):
        """Simulated expensive computation."""
        print(f"执行昂贵计算: {n}")
        time.sleep(1)
        return n * n

    print("\n5. 缓存装饰器示例:")
    for i in [1, 2, 3, 1, 2, 4, 5, 1]:
        result = expensive_computation(i)
        print(f"计算结果: {result}")
    print(f"缓存信息: {expensive_computation.cache_info()}")

    # Context-manager decorator.
    print("\n6. 上下文管理器装饰器示例:")
    with ContextManagerDecorators.timed_block():
        time.sleep(0.5)
        print("在计时块中执行代码")

    # Mini ORM.
    print("\n7. 迷你ORM示例:")

    class User(MiniORM.Model):
        __tablename__ = 'users'
        id = MiniORM.IntegerField(nullable=False)
        name = MiniORM.StringField(max_length=100, nullable=False)
        email = MiniORM.StringField(max_length=255)
        age = MiniORM.IntegerField(nullable=True, default=18)

    user = User(id=1, name="张三", email="zhangsan@example.com", age=25)
    print(f"用户对象: {user}")
    print(f"用户字典: {user.to_dict()}")

    # A failing validation.
    try:
        User(id="不是数字", name="李四")
    except ValueError as e:
        print(f"验证失败: {e}")

    # Descriptors.
    print("\n8. 描述符示例:")

    class Person:
        age = Descriptors.ValidatedAttribute(
            validator=lambda x: isinstance(x, int) and 0 <= x <= 150
        )

        @Descriptors.CachedProperty
        def birth_year(self):
            """Birth year derived from age (computed once, then cached)."""
            print("计算出生年份...")
            current_year = datetime.now().year
            return current_year - self.age

    person = Person()
    person.age = 30
    print(f"年龄: {person.age}")
    print(f"出生年份: {person.birth_year}")
    print(f"再次获取出生年份: {person.birth_year}")  # served from the cache

    try:
        person.age = 200  # rejected by the validator
    except ValueError as e:
        print(f"设置无效年龄: {e}")

    # Observable attribute.
    print("\n9. 可观察属性示例:")

    class ObservableModel:
        value = Descriptors.ObservableAttribute(default=0)

        def __init__(self):
            # Look the descriptor up on the class: self.value goes through
            # __get__ and returns the stored value (an int), which has no
            # add_observer method.
            type(self).value.add_observer(self, self.on_value_changed)

        def on_value_changed(self, obj, old_value, new_value):
            print(f"值从 {old_value} 变为 {new_value}")

    model = ObservableModel()
    model.value = 10
    model.value = 20

    print("\n" + "=" * 50)
    print("装饰器与元编程演示完成!")
    print("=" * 50)
if __name__ == "__main__":
    # Script entry point: run every decorator/metaprogramming demo.
    demonstration()
装饰器与元编程的核心思想:
- 装饰器是接收函数/类并返回其修改或增强版本的可调用对象
- 元类是创建类的类,控制类的创建行为
- 描述符是管理属性访问的协议
- 合理使用这些特性可以让代码更加优雅和强大
4. 单元测试与测试驱动开发
测试是保证代码质量的重要手段。Python的unittest和pytest框架提供了完善的测试支持。
python
import unittest
import pytest
import tempfile
import json
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from io import StringIO
import sys
print("=== Python测试实战指南 ===")


# 1. Code under test.
class Calculator:
    """Small calculator used as the system under test."""

    def add(self, a, b):
        """Return a + b; both operands must be int or float."""
        numeric = (int, float)
        if not isinstance(a, numeric) or not isinstance(b, numeric):
            raise TypeError("参数必须是数字")
        return a + b

    def subtract(self, a, b):
        """Return a - b."""
        return a - b

    def multiply(self, a, b):
        """Return a * b."""
        return a * b

    def divide(self, a, b):
        """Return a / b; raise ValueError when b is zero."""
        if b == 0:
            raise ValueError("除数不能为零")
        return a / b

    def factorial(self, n):
        """Return n! for a non-negative integer n."""
        if not isinstance(n, int):
            raise TypeError("参数必须是整数")
        if n < 0:
            raise ValueError("参数不能为负数")
        product = 1
        k = n
        while k > 1:
            product *= k
            k -= 1
        return product
class UserManager:
    """In-memory user store with auto-incrementing ids (test fixture)."""

    def __init__(self):
        self.users = {}
        self.next_id = 1

    def add_user(self, username, email):
        """Create a user record and return its new id.

        Raises:
            ValueError: if username or email is empty, or email lacks '@'.
        """
        if not (username and email):
            raise ValueError("用户名和邮箱不能为空")
        if "@" not in email:
            raise ValueError("邮箱格式不正确")
        new_id = self.next_id
        self.users[new_id] = {'id': new_id, 'username': username, 'email': email}
        self.next_id = new_id + 1
        return new_id

    def get_user(self, user_id):
        """Return the user record for *user_id*, or None if absent."""
        return self.users.get(user_id)

    def delete_user(self, user_id):
        """Remove and return the user record; raise KeyError if absent."""
        if user_id in self.users:
            return self.users.pop(user_id)
        raise KeyError(f"用户ID {user_id} 不存在")

    def get_all_users(self):
        """Return a list of every stored user record."""
        return [*self.users.values()]
class FileProcessor:
    """Upper-cases text files, writing '<stem>.processed.txt' beside each input (test fixture)."""

    def __init__(self):
        self.processed_files = []

    def process_file(self, filepath):
        """Upper-case the file's UTF-8 text, save it next to the input, and return it.

        Raises:
            FileNotFoundError: if *filepath* does not exist.
            ValueError: if *filepath* exists but is not a regular file.
        """
        source = Path(filepath)
        if not source.exists():
            raise FileNotFoundError(f"文件不存在: {filepath}")
        if not source.is_file():
            raise ValueError(f"不是文件: {filepath}")

        upper_text = source.read_text(encoding='utf-8').upper()
        target = source.with_suffix('.processed.txt')
        target.write_text(upper_text, encoding='utf-8')
        self.processed_files.append(str(target))
        return upper_text

    def get_processed_count(self):
        """Return how many files have been processed so far."""
        return len(self.processed_files)
# 2. Tests written with unittest.
class TestCalculator(unittest.TestCase):
    """unittest suite for Calculator."""

    def setUp(self):
        """Give each test a fresh Calculator and announce the test name."""
        self.calc = Calculator()
        print(f"开始测试 {self._testMethodName}")

    def tearDown(self):
        """Announce that the current test has finished."""
        print(f"完成测试 {self._testMethodName}")

    def test_add_integers(self):
        """Integer addition."""
        self.assertEqual(self.calc.add(2, 3), 5)

    def test_add_floats(self):
        """Float addition, compared to one decimal place."""
        self.assertAlmostEqual(self.calc.add(2.5, 3.1), 5.6, places=1)

    def test_add_invalid_type(self):
        """A string operand is rejected with TypeError."""
        self.assertRaises(TypeError, self.calc.add, "2", 3)

    def test_divide_by_zero(self):
        """Dividing by zero raises ValueError."""
        self.assertRaises(ValueError, self.calc.divide, 10, 0)

    def test_factorial(self):
        """Factorial of several small inputs, one subTest per input."""
        expected_by_n = {0: 1, 1: 1, 5: 120}
        for n, expected in expected_by_n.items():
            with self.subTest(n=n):
                self.assertEqual(self.calc.factorial(n), expected)

    def test_factorial_negative(self):
        """A negative input raises ValueError."""
        self.assertRaises(ValueError, self.calc.factorial, -1)
class TestUserManager(unittest.TestCase):
    """unittest tests for UserManager: creation, validation, lookup, deletion."""

    def setUp(self):
        # Every test starts from an empty manager.
        self.manager = UserManager()

    def test_add_user_success(self):
        """A valid user gets id 1 and its fields are stored."""
        new_id = self.manager.add_user("testuser", "test@example.com")
        self.assertEqual(new_id, 1)
        stored = self.manager.get_user(new_id)
        for field, expected in (("username", "testuser"), ("email", "test@example.com")):
            self.assertEqual(stored[field], expected)

    def test_add_user_empty_username(self):
        """An empty username is rejected with ValueError."""
        self.assertRaises(ValueError, self.manager.add_user, "", "test@example.com")

    def test_add_user_invalid_email(self):
        """An e-mail address without '@' is rejected with ValueError."""
        self.assertRaises(ValueError, self.manager.add_user, "testuser", "invalid-email")

    def test_get_nonexistent_user(self):
        """Looking up an unknown id yields None."""
        self.assertIsNone(self.manager.get_user(999))

    def test_delete_user(self):
        """Deleting returns the removed record and empties the manager."""
        new_id = self.manager.add_user("testuser", "test@example.com")
        self.assertEqual(len(self.manager.get_all_users()), 1)
        removed = self.manager.delete_user(new_id)
        self.assertEqual(removed['username'], "testuser")
        self.assertEqual(self.manager.get_all_users(), [])

    def test_delete_nonexistent_user(self):
        """Deleting an unknown id raises KeyError."""
        self.assertRaises(KeyError, self.manager.delete_user, 999)
# 3. 使用pytest的测试
class TestFileProcessor:
    """pytest-style tests for FileProcessor (uses pytest's built-in tmp_path fixture)."""

    def test_process_file_success(self, tmp_path):
        """A regular file is upper-cased and a .processed.txt copy is written."""
        source = tmp_path / "test.txt"
        source.write_text("Hello, World!", encoding='utf-8')

        fp = FileProcessor()
        output = fp.process_file(str(source))

        expected = "HELLO, WORLD!"
        # Returned text and bookkeeping.
        assert output == expected
        assert fp.get_processed_count() == 1
        # The sibling output file exists and holds the same text.
        written = tmp_path / "test.processed.txt"
        assert written.exists()
        assert written.read_text(encoding='utf-8') == expected

    def test_process_nonexistent_file(self):
        """A missing path raises FileNotFoundError."""
        fp = FileProcessor()
        with pytest.raises(FileNotFoundError):
            fp.process_file("/nonexistent/file.txt")

    def test_process_directory(self, tmp_path):
        """A directory path is rejected with a ValueError whose message matches."""
        fp = FileProcessor()
        with pytest.raises(ValueError, match="不是文件"):
            fp.process_file(str(tmp_path))
# 4. Mock和Patch测试
class TestWithMocks:
    """Examples of unittest.mock (Mock, side_effect, patch) inside pytest tests."""

    def test_mock_example(self):
        """Configure return values on a Mock and verify how it was called."""
        fake = Mock()
        fake.some_method.return_value = 42
        fake.another_method.return_value = "Hello"

        assert fake.some_method() == 42
        assert fake.another_method() == "Hello"

        # Each configured method was invoked exactly once above.
        for method in (fake.some_method, fake.another_method):
            method.assert_called_once()

    def test_mock_with_side_effect(self):
        """An iterable side_effect yields successive values, then raises."""
        fake = Mock()
        # side_effect accepts a callable, an exception, or an iterable. With an
        # iterable, each call returns the next item; an exception instance in
        # the sequence is raised instead of returned.
        fake.get_value.side_effect = [1, 2, 3, ValueError("No more values")]
        for expected in (1, 2, 3):
            assert fake.get_value() == expected
        with pytest.raises(ValueError):
            fake.get_value()

    @patch('builtins.print')
    def test_patch_builtin(self, mock_print):
        """Replace builtins.print and inspect the calls it received."""
        messages = ("Hello, World!", "Another message")
        for message in messages:
            print(message)
        assert mock_print.call_count == 2
        for message in messages:
            mock_print.assert_any_call(message)

    @patch('os.path.exists')
    def test_patch_os_function(self, mock_exists):
        """Patch os.path.exists; the function is imported after patching."""
        mock_exists.return_value = False
        # `from os.path import exists` reads the attribute at this moment,
        # so it picks up the mock installed by @patch.
        from os.path import exists
        assert not exists("/some/path")
        mock_exists.assert_called_once_with("/some/path")
# 5. 测试夹具(Fixtures)
class TestWithFixtures:
    """Demonstrates pytest fixtures defined as methods on a test class."""

    @pytest.fixture
    def sample_data(self):
        """Static sample data for tests that request it."""
        return dict(
            numbers=[1, 2, 3, 4, 5],
            strings=['a', 'b', 'c'],
            nested={'key': 'value'},
        )

    @pytest.fixture
    def calculator(self):
        """A fresh Calculator instance per test."""
        instance = Calculator()
        return instance

    @pytest.fixture
    def temp_file(self, tmp_path):
        """A two-line UTF-8 text file inside pytest's per-test tmp_path."""
        target = tmp_path / "test_data.txt"
        target.write_text("Test content\nSecond line", encoding='utf-8')
        return target

    def test_with_sample_data(self, sample_data):
        """The sample_data fixture is injected by parameter name."""
        numbers = sample_data['numbers']
        assert len(numbers) == 5
        assert sample_data['nested']['key'] == 'value'

    def test_calculator_with_fixture(self, calculator):
        """The calculator fixture provides a ready-to-use Calculator."""
        assert calculator.add(2, 3) == 5
        assert calculator.multiply(4, 5) == 20

    def test_file_processing(self, temp_file):
        """FileProcessor upper-cases the temp_file fixture's content."""
        output = FileProcessor().process_file(str(temp_file))
        assert "TEST CONTENT" in output
# 6. 参数化测试
class TestParameterized:
    """Table-driven tests using pytest.mark.parametrize."""

    @pytest.mark.parametrize(("a", "b", "expected"), [
        (1, 2, 3),
        (0, 0, 0),
        (-1, 1, 0),
        (2.5, 3.5, 6.0),
    ])
    def test_addition(self, a, b, expected):
        """Calculator.add over integer, zero, negative and float operands."""
        assert Calculator().add(a, b) == expected

    @pytest.mark.parametrize(("n", "expected"), [
        (0, 1),
        (1, 1),
        (5, 120),
        (10, 3628800),
    ])
    def test_factorial(self, n, expected):
        """Calculator.factorial against known values, including 0! == 1."""
        assert Calculator().factorial(n) == expected

    @pytest.mark.parametrize(("username", "email", "should_succeed"), [
        ("user1", "user1@example.com", True),
        ("", "user@example.com", False),    # empty username
        ("user2", "invalid-email", False),  # e-mail without '@'
        ("user3", "", False),               # empty e-mail
    ])
    def test_add_user_validation(self, username, email, should_succeed):
        """Valid input yields id 1 on a fresh manager; invalid input raises ValueError."""
        manager = UserManager()
        if not should_succeed:
            with pytest.raises(ValueError):
                manager.add_user(username, email)
            return
        assert manager.add_user(username, email) == 1
# 7. 测试覆盖率和性能测试
class TestCoverageAndPerformance:
    """Simple timing and allocation checks (illustrative, not a real benchmark)."""

    def test_performance_basic(self):
        """1000 calls to factorial(10) must finish in under one second."""
        import time
        calc = Calculator()
        start_time = time.perf_counter()
        for _ in range(1000):
            calc.factorial(10)
        elapsed = time.perf_counter() - start_time
        # Wall-clock assertion: may be flaky on heavily loaded machines.
        assert elapsed < 1.0, f"性能测试失败: 耗时 {elapsed:.2f}秒"
        print(f"性能测试通过: 耗时 {elapsed:.4f}秒")

    def test_memory_usage(self):
        """Trace allocations while adding 1000 users and print the top 5 sites.

        Tracing is stopped in a ``finally`` block so an exception part-way
        through cannot leave tracemalloc running (and slowing down) later
        tests. It is only stopped if this test started it, so an outer
        tracing session (e.g. ``python -X tracemalloc``) is left intact.
        """
        import tracemalloc
        already_tracing = tracemalloc.is_tracing()
        if not already_tracing:
            tracemalloc.start()
        try:
            manager = UserManager()
            for i in range(1000):
                manager.add_user(f"user{i}", f"user{i}@example.com")
            # A Snapshot is a standalone object; it stays usable after stop().
            snapshot = tracemalloc.take_snapshot()
        finally:
            if not already_tracing:
                tracemalloc.stop()
        top_stats = snapshot.statistics('lineno')
        print("\n内存使用情况:")
        for stat in top_stats[:5]:  # top 5 allocation sites
            print(stat)
        # Sanity check that the workload actually ran.
        assert len(manager.get_all_users()) == 1000
# 8. 集成测试和端到端测试
class TestIntegration:
    """Integration tests that chain several operations together."""

    def test_calculator_integration(self):
        """(10 + 20) * 2 / 3 computed step by step via Calculator equals 20.0."""
        calc = Calculator()
        total = calc.add(10, 20)
        doubled = calc.multiply(total, 2)
        assert calc.divide(doubled, 3) == 20.0

    def test_user_manager_integration(self):
        """Add three users, delete the middle one, and check who remains."""
        manager = UserManager()
        ids = [manager.add_user(f"user{n}", f"user{n}@example.com") for n in range(3)]
        assert len(manager.get_all_users()) == 3

        manager.delete_user(ids[1])
        assert len(manager.get_all_users()) == 2

        names = {u['username'] for u in manager.get_all_users()}
        assert "user0" in names
        assert "user2" in names
        assert "user1" not in names
# 9. 测试运行器
def run_tests():
    """Run the unittest-based suites (TestCalculator, TestUserManager) and print a summary.

    Returns:
        The ``unittest.TestResult`` of the run, so callers can derive an exit code.

    The pytest-style classes in this file are NOT executed here; run them
    with the ``pytest`` command line.
    """
    print("=" * 50)
    print("运行测试套件")
    print("=" * 50)

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case_class in (TestCalculator, TestUserManager):
        suite.addTests(loader.loadTestsFromTestCase(case_class))

    runner = unittest.TextTestRunner(verbosity=2)
    print("\n运行unittest测试:")
    result = runner.run(suite)

    failed = len(result.failures)
    errored = len(result.errors)
    skipped = len(result.skipped)
    # Skipped tests are included in testsRun but never ran their assertions,
    # so they must not be counted as successes.
    passed = result.testsRun - failed - errored - skipped
    print("\n测试结果:")
    print(f"运行测试: {result.testsRun}")
    print(f"成功: {passed}")
    print(f"失败: {failed}")
    print(f"错误: {errored}")
    print(f"跳过: {skipped}")

    print("\n" + "=" * 50)
    print("运行pytest测试")
    print("=" * 50)
    # Note: pytest tests should be run from the command line in practice;
    # this only demonstrates how the test code is organised.
    return result
# 10. 测试驱动开发(TDD)示例
class TDDBankAccount:
    """Bank account used as a test-driven-development (TDD) example.

    Tracks a running balance and an append-only list of
    ``(kind, amount)`` transaction tuples.
    """

    def __init__(self, initial_balance=0):
        self.balance = initial_balance
        # Chronological list of ('deposit' | 'withdraw', amount) tuples.
        self.transactions = []

    def deposit(self, amount):
        """Add *amount* to the balance and return the new balance.

        Raises:
            ValueError: unless ``amount > 0`` (this also rejects NaN).
        """
        # `not (amount > 0)` instead of `amount <= 0`: every comparison with
        # NaN is False, so `nan <= 0` would let NaN through and poison the balance.
        if not (amount > 0):
            raise ValueError("存款金额必须大于0")
        self.balance += amount
        self.transactions.append(('deposit', amount))
        return self.balance

    def withdraw(self, amount):
        """Subtract *amount* from the balance and return the new balance.

        Raises:
            ValueError: unless ``amount > 0`` (this also rejects NaN), or if
                *amount* exceeds the current balance.
        """
        # Same NaN-safe check as deposit().
        if not (amount > 0):
            raise ValueError("取款金额必须大于0")
        if amount > self.balance:
            raise ValueError("余额不足")
        self.balance -= amount
        self.transactions.append(('withdraw', amount))
        return self.balance

    def get_balance(self):
        """Return the current balance."""
        return self.balance

    def get_transaction_history(self):
        """Return a copy of the transaction list, so callers cannot mutate internal state."""
        return self.transactions.copy()
class TestBankAccountTDD(unittest.TestCase):
    """TDD-style unit tests for TDDBankAccount."""

    def test_initial_balance(self):
        """The constructor argument becomes the starting balance."""
        self.assertEqual(TDDBankAccount(100).get_balance(), 100)

    def test_deposit_positive(self):
        """A positive deposit is returned as the new balance and stored."""
        acct = TDDBankAccount()
        self.assertEqual(acct.deposit(50), 50)
        self.assertEqual(acct.get_balance(), 50)

    def test_deposit_zero(self):
        """Depositing zero raises ValueError."""
        acct = TDDBankAccount()
        self.assertRaises(ValueError, acct.deposit, 0)

    def test_deposit_negative(self):
        """Depositing a negative amount raises ValueError."""
        acct = TDDBankAccount()
        self.assertRaises(ValueError, acct.deposit, -10)

    def test_withdraw_success(self):
        """A covered withdrawal returns the reduced balance."""
        acct = TDDBankAccount(100)
        self.assertEqual(acct.withdraw(30), 70)

    def test_withdraw_insufficient_funds(self):
        """Withdrawing more than the balance raises ValueError."""
        acct = TDDBankAccount(50)
        self.assertRaises(ValueError, acct.withdraw, 100)

    def test_transaction_history(self):
        """Deposits and withdrawals are recorded in order."""
        acct = TDDBankAccount(100)
        acct.deposit(50)
        acct.withdraw(30)
        history = acct.get_transaction_history()
        self.assertEqual(len(history), 2)
        first, second = history
        self.assertEqual(first, ('deposit', 50))
        self.assertEqual(second, ('withdraw', 30))
# 主函数
if __name__ == "__main__":
    # Imported locally so this entry block does not depend on the file's import list.
    import sys

    print("=" * 50)
    print("Python测试实战指南")
    print("=" * 50)

    # Run the TDD example tests.
    print("\nTDD示例测试:")
    tdd_suite = unittest.TestLoader().loadTestsFromTestCase(TestBankAccountTDD)
    tdd_runner = unittest.TextTestRunner(verbosity=1)
    tdd_result = tdd_runner.run(tdd_suite)

    # Run the remaining unittest suites.
    print("\n运行完整测试套件:")
    result = run_tests()

    print("\n" + "=" * 50)
    print("测试完成!")
    print("=" * 50)

    # Exit code for CI/CD: non-zero if either run had failures or errors.
    # It must be handed to sys.exit(); only printing it leaves the process
    # exit status at 0, so CI would report a failing run as green.
    all_passed = tdd_result.wasSuccessful() and result.wasSuccessful()
    exit_code = 0 if all_passed else 1
    print(f"\n退出码: {exit_code}")
    sys.exit(exit_code)

# 测试最佳实践:
- 遵循测试金字塔:按测试数量,单元测试 > 集成测试 > 端到端测试
- 测试应该独立、快速、可重复
- 使用Mock和Patch来隔离外部依赖
- 测试覆盖率不是唯一目标,测试质量更重要
- 实践测试驱动开发(TDD)可以提高代码质量
5. 代码打包与分发
将Python代码打包成可以分发的格式是每个开发者都需要掌握的技能。
python
#!/usr/bin/env python3
"""
Python包打包与分发完整示例
项目结构:
my_package/
├── setup.py # 打包配置文件
├── pyproject.toml # 现代打包配置
├── README.md # 项目说明
├── LICENSE # 许可证
├── requirements.txt # 依赖列表
├── tests/ # 测试目录
├── docs/ # 文档目录
└── src/ # 源代码目录
└── my_package/ # 包目录
├── __init__.py
├── core.py
├── utils.py
└── cli.py
"""
import os
import sys
import shutil
from pathlib import Path
from setuptools import setup, find_packages
from setuptools.command.build_py import build_py
from setuptools.command.sdist import sdist
import subprocess
# Module-level banner: printed as soon as this packaging demo script is run or imported.
print("=== Python包打包与分发实战 ===")
# 1. 创建项目目录结构
def create_project_structure():
"""创建标准的Python项目结构"""
project_name = "my_package"
print(f"1. 创建项目结构: {project_name}")
# 基础目录
directories = [
f"{project_name}/src/{project_name}",
f"{project_name}/tests",
f"{project_name}/docs",
f"{project_name}/examples",
]
for directory in directories:
os.makedirs(directory, exist_ok=True)
print(f" 创建目录: {directory}")
# 创建文件
files_to_create = {
# 包文件
f"{project_name}/src/{project_name}/__init__.py": '''"""
My Package - 一个演示用的Python包
功能:
1. 核心功能
2. 工具函数
3. CLI接口
"""
__version__ = "0.1.0"
__author__ = "Your Name"
__email__ = "your.email@example.com"
from .core import CoreClass, main_function
from .utils import helper_function, DataProcessor
from .cli import main as cli_main
__all__ = [
"CoreClass",
"main_function",
"helper_function",
"DataProcessor",
"cli_main",
]
''',
f"{project_name}/src/{project_name}/core.py": '''"""
核心模块
"""
import logging
from typing import Any, List, Dict, Optional
from dataclasses import dataclass
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class CoreClass:
"""核心类示例"""
name: str
value: int = 0
metadata: Dict[str, Any] = None
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
def process(self) -> str:
"""处理数据"""
logger.info(f"处理 {self.name}")
result = f"Processed {self.name} with value {self.value}"
if self.metadata:
result += f" and metadata {self.metadata}"
return result
def update_value(self, increment: int = 1) -> None:
"""更新值"""
self.value += increment
logger.debug(f"更新 {self.name} 的值为 {self.value}")
def main_function(input_data: List[str]) -> List[str]:
"""
主处理函数
Args:
input_data: 输入字符串列表
Returns:
处理后的字符串列表
Raises:
ValueError: 如果输入为空
Example:
>>> main_function(["a", "b"])
['A', 'B']
"""
if not input_data:
raise ValueError("输入数据不能为空")
logger.info(f"处理 {len(input_data)} 条数据")
return [item.upper() for item in input_data]
if __name__ == "__main__":
# 测试代码
obj = CoreClass("test", 42, {"key": "value"})
print(obj.process())
''',
f"{project_name}/src/{project_name}/utils.py": '''"""
工具模块
"""
import json
import hashlib
from pathlib import Path
from datetime import datetime
from typing import Union, Any
import csv
def helper_function(data: Union[str, list, dict]) -> str:
"""
辅助函数示例
Args:
data: 输入数据
Returns:
处理后的字符串
"""
if isinstance(data, str):
return data.upper()
elif isinstance(data, list):
return ', '.join(map(str, data))
elif isinstance(data, dict):
return json.dumps(data, ensure_ascii=False)
else:
return str(data)
class DataProcessor:
"""数据处理类"""
def __init__(self, data_dir: Union[str, Path] = "."):
self.data_dir = Path(data_dir)
self.processed_count = 0
def load_json(self, filename: str) -> dict:
"""加载JSON文件"""
filepath = self.data_dir / filename
if not filepath.exists():
raise FileNotFoundError(f"文件不存在: {filepath}")
with open(filepath, 'r', encoding='utf-8') as f:
return json.load(f)
def save_json(self, data: dict, filename: str) -> None:
"""保存JSON文件"""
filepath = self.data_dir / filename
# 确保目录存在
filepath.parent.mkdir(parents=True, exist_ok=True)
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
self.processed_count += 1
def calculate_hash(self, data: Any) -> str:
"""计算数据的哈希值"""
if isinstance(data, dict):
data_str = json.dumps(data, sort_keys=True)
else:
data_str = str(data)
return hashlib.sha256(data_str.encode()).hexdigest()
def process_csv(self, input_file: str, output_file: str) -> None:
"""处理CSV文件"""
input_path = self.data_dir / input_file
output_path = self.data_dir / output_file
with open(input_path, 'r', encoding='utf-8') as f_in, \
open(output_path, 'w', encoding='utf-8', newline='') as f_out:
reader = csv.DictReader(f_in)
fieldnames = reader.fieldnames
if fieldnames:
writer = csv.DictWriter(f_out, fieldnames=fieldnames)
writer.writeheader()
for row in reader:
# 处理每一行数据
processed_row = {
key: value.upper() if isinstance(value, str) else value
for key, value in row.items()
}
writer.writerow(processed_row)
self.processed_count += 1
def format_timestamp(timestamp: datetime = None) -> str:
"""格式化时间戳"""
if timestamp is None:
timestamp = datetime.now()
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
''',
f"{project_name}/src/{project_name}/cli.py": '''"""
命令行接口
"""
import argparse
import sys
from typing import List
from .core import main_function, CoreClass
from .utils import DataProcessor, helper_function
def create_parser() -> argparse.ArgumentParser:
"""创建命令行解析器"""
parser = argparse.ArgumentParser(
description="My Package - 命令行工具",
epilog="示例: python -m my_package.cli process --input data.txt"
)
# 子命令
subparsers = parser.add_subparsers(dest='command', help='可用的命令')
# process 命令
process_parser = subparsers.add_parser('process', help='处理数据')
process_parser.add_argument(
'--input', '-i',
nargs='+',
required=True,
help='输入数据'
)
process_parser.add_argument(
'--output', '-o',
help='输出文件(可选)'
)
# info 命令
info_parser = subparsers.add_parser('info', help='显示包信息')
info_parser.add_argument(
'--verbose', '-v',
action='store_true',
help='显示详细信息'
)
# utils 命令
utils_parser = subparsers.add_parser('utils', help='工具函数')
utils_parser.add_argument(
'function',
choices=['hash', 'format'],
help='要使用的工具函数'
)
utils_parser.add_argument(
'data',
help='要处理的数据'
)
return parser
def handle_process(args) -> int:
"""处理process命令"""
print(f"处理 {len(args.input)} 个输入项")
try:
result = main_function(args.input)
for item in result:
print(f"结果: {item}")
if args.output:
with open(args.output, 'w', encoding='utf-8') as f:
for item in result:
f.write(f"{item}\\n")
print(f"结果已保存到: {args.output}")
return 0
except Exception as e:
print(f"处理失败: {e}", file=sys.stderr)
return 1
def handle_info(args) -> int:
"""处理info命令"""
from . import __version__, __author__
print(f"My Package 版本: {__version__}")
print(f"作者: {__author__}")
if args.verbose:
print("\\n详细功能:")
print("1. 核心功能: CoreClass, main_function")
print("2. 工具函数: helper_function, DataProcessor")
print("3. 命令行接口: 当前正在使用")
return 0
def handle_utils(args) -> int:
"""处理utils命令"""
from .utils import DataProcessor
processor = DataProcessor()
if args.function == 'hash':
hash_value = processor.calculate_hash(args.data)
print(f"哈希值: {hash_value}")
elif args.function == 'format':
formatted = helper_function(args.data)
print(f"格式化结果: {formatted}")
return 0
def main(argv: List[str] = None) -> int:
"""主函数"""
if argv is None:
argv = sys.argv[1:]
parser = create_parser()
args = parser.parse_args(argv)
if not args.command:
parser.print_help()
return 1
# 根据命令调用对应的处理函数
command_handlers = {
'process': handle_process,
'info': handle_info,
'utils': handle_utils,
}
handler = command_handlers.get(args.command)
if handler:
return handler(args)
else:
print(f"未知命令: {args.command}", file=sys.stderr)
return 1
if __name__ == "__main__":
sys.exit(main())
''',
# 测试文件
f"{project_name}/tests/__init__.py": "",
f"{project_name}/tests/test_core.py": '''"""
核心模块测试
"""
import pytest
from my_package.core import CoreClass, main_function
def test_core_class():
"""测试CoreClass"""
obj = CoreClass("test", 42, {"key": "value"})
assert obj.name == "test"
assert obj.value == 42
assert obj.metadata == {"key": "value"}
result = obj.process()
assert "test" in result
assert "42" in result
obj.update_value(10)
assert obj.value == 52
def test_main_function():
"""测试main_function"""
result = main_function(["a", "b", "c"])
assert result == ["A", "B", "C"]
# 测试空输入
with pytest.raises(ValueError):
main_function([])
def test_main_function_empty_string():
"""测试空字符串输入"""
result = main_function(["", "hello"])
assert result == ["", "HELLO"]
if __name__ == "__main__":
pytest.main([__file__])
''',
f"{project_name}/tests/test_utils.py": '''"""
工具模块测试
"""
import json
import tempfile
from pathlib import Path
from my_package.utils import helper_function, DataProcessor, format_timestamp
def test_helper_function():
"""测试helper_function"""
# 测试字符串
assert helper_function("hello") == "HELLO"
# 测试列表
assert helper_function([1, 2, 3]) == "1, 2, 3"
# 测试字典
data = {"key": "value"}
result = helper_function(data)
assert "key" in result
assert "value" in result
def test_data_processor():
"""测试DataProcessor"""
with tempfile.TemporaryDirectory() as tmpdir:
processor = DataProcessor(tmpdir)
# 测试保存和加载JSON
test_data = {"test": "data", "number": 42}
processor.save_json(test_data, "test.json")
loaded_data = processor.load_json("test.json")
assert loaded_data == test_data
# 测试哈希计算
hash_value = processor.calculate_hash(test_data)
assert len(hash_value) == 64 # SHA256哈希长度
# 测试计数器
assert processor.processed_count == 1
def test_format_timestamp():
"""测试format_timestamp"""
from datetime import datetime
timestamp = format_timestamp()
assert len(timestamp) == 19 # YYYY-MM-DD HH:MM:SS
# 测试指定时间戳
dt = datetime(2023, 1, 1, 12, 0, 0)
formatted = format_timestamp(dt)
assert formatted == "2023-01-01 12:00:00"
''',
# 配置文件
f"{project_name}/setup.py": '''"""
打包配置文件 (传统方式)
"""
from setuptools import setup, find_packages
import os
# 读取README
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
# 读取requirements
def read_requirements():
with open("requirements.txt", "r", encoding="utf-8") as f:
return [line.strip() for line in f if line.strip() and not line.startswith("#")]
setup(
name="my-package",
version="0.1.0",
author="Your Name",
author_email="your.email@example.com",
description="一个演示用的Python包",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/yourusername/my-package",
# 包发现
package_dir={"": "src"},
packages=find_packages(where="src"),
# 分类器
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
],
# Python版本要求
python_requires=">=3.7",
# 依赖
install_requires=read_requirements(),
# 可选依赖
extras_require={
"dev": [
"pytest>=6.0",
"pytest-cov>=2.0",
"black>=21.0",
"flake8>=4.0",
"mypy>=0.900",
],
"docs": [
"sphinx>=4.0",
"sphinx-rtd-theme>=1.0",
],
},
# 入口点
entry_points={
"console_scripts": [
"my-package=my_package.cli:main",
],
},
# 包含数据文件
include_package_data=True,
package_data={
"my_package": ["data/*.json", "config/*.yaml"],
},
)
''',
f"{project_name}/pyproject.toml": '''[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "my-package"
version = "0.1.0"
authors = [
{name = "Your Name", email = "your.email@example.com"}
]
description = "一个演示用的Python包"
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.7"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
dependencies = [
"requests>=2.25.0",
"click>=8.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=6.0",
"pytest-cov>=2.0",
"black>=21.0",
"flake8>=4.0",
"mypy>=0.900",
]
docs = [
"sphinx>=4.0",
"sphinx-rtd-theme>=1.0",
]
[project.scripts]
my-package = "my_package.cli:main"
[project.urls]
Homepage = "https://github.com/yourusername/my-package"
BugTracker = "https://github.com/yourusername/my-package/issues"
Documentation = "https://my-package.readthedocs.io/"
[tool.setuptools]
package-dir = {"" = "src"}
packages = {find = {where = ["src"]}}
[tool.setuptools.package-data]
"my_package" = ["data/*.json", "config/*.yaml"]
[tool.black]
line-length = 88
target-version = ['py37']
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
[tool.mypy]
python_version = "3.7"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
''',
f"{project_name}/README.md": '''# My Package
一个演示用的Python包,展示如何正确打包和分发Python项目。
## 功能特性
- ✅ 核心功能模块
- ✅ 工具函数集合
- ✅ 命令行接口
- ✅ 完整测试套件
- ✅ 类型注解支持
## 安装
### 从PyPI安装(如果已发布)
```bash
pip install my-package从源码安装
bash
# 克隆仓库
git clone https://github.com/yourusername/my-package.git
cd my-package
# 安装(开发模式)
pip install -e .
# 安装(带开发依赖)
pip install -e ".[dev]"