单元测试是保证代码质量的重要手段。pytest 是 Python 最流行的测试框架,简洁强大。本文介绍 pytest 的核心功能:fixture、parametrize 和 mock,以及如何写出可维护的测试代码。
pytest 基础用法
安装与运行
# Install pytest plus the coverage plugin
pip install pytest pytest-cov

# Run every test pytest discovers (starting from the current directory)
pytest

# Run a single test file
pytest tests/test_user.py

# Run one test function in a file (node id: path::name)
pytest tests/test_user.py::test_get_user

# Verbose: print each test's name and outcome
pytest -v

# Stop after the first failure
pytest -x
断言
pytest 内置断言,失败时会显示详细信息:
import pytest


def test_basic():
    """Plain ``assert`` statements; pytest's assertion rewriting reports the compared values on failure."""
    assert 1 + 1 == 2
    assert "hello".upper() == "HELLO"
    assert [1, 2, 3] == [1, 2, 3]

    # pytest.raises passes only if the block raises the given exception type.
    with pytest.raises(ValueError):
        int("not a number")

    # Floats need a tolerance: 0.1 + 0.2 is not exactly 0.3.
    assert 0.1 + 0.2 == pytest.approx(0.3)
fixture 作用域与共享
fixture 是 pytest 最强大的功能之一,用于提供测试数据、setup/teardown。
基本用法
import pytest


@pytest.fixture
def user():
    """Provide a fresh user dict to every test that asks for ``user``."""
    return {"name": "Alice", "age": 25}


def test_user_name(user):
    # pytest injects the fixture's return value by matching the argument name.
    assert user["name"] == "Alice"
scope 作用域
import pytest

# create_db_connection / load_config / launch_browser stand for your own helpers.


# scope="function" (the default): set up and torn down around every test.
@pytest.fixture(scope="function")
def db_connection():
    conn = create_db_connection()
    yield conn      # the test runs here
    conn.close()    # teardown runs after the test, whatever its outcome


# scope="class": shared by all tests in one test class.
# It needs its own name: a second ``def db_connection`` in the same module
# would silently replace the function-scoped fixture above.
@pytest.fixture(scope="class")
def class_db_connection():
    conn = create_db_connection()
    yield conn
    conn.close()


# scope="module": created once per test module (file).
@pytest.fixture(scope="module")
def config():
    return load_config()


# scope="session": created once for the whole pytest run.
@pytest.fixture(scope="session")
def browser():
    driver = launch_browser()
    yield driver
    driver.quit()
fixture 依赖
import pytest

# create_db_connection, UserRepository and User stand for project code.


@pytest.fixture
def db_connection():
    conn = create_db_connection()
    yield conn
    conn.close()  # release the connection once the test is done


@pytest.fixture
def user_repository(db_connection):
    # A fixture can request another fixture: pytest builds db_connection
    # first and passes it in here.
    return UserRepository(db_connection)


def test_save_user(user_repository):
    user = User(name="Alice")
    user_repository.save(user)
    # presumably save() assigns the primary key — depends on UserRepository
    assert user.id is not None
autouse 自动执行
import pytest


@pytest.fixture(autouse=True)
def setup_teardown():
    """Wrap every test that can see this fixture, without it being requested.

    "Can see" means tests in the same module, or below the conftest.py that
    defines it. pytest captures stdout by default; run ``pytest -s`` to see
    the prints.
    """
    print("\nsetup")     # before each test
    yield
    print("\nteardown")  # after each test
fixture 返回字典/对象
from dataclasses import dataclass

import pytest


# 1. A fixture may return a plain dict ...
@pytest.fixture
def user_dict():
    return {"name": "Alice", "age": 25}


@dataclass
class User:
    name: str
    age: int


# 2. ... or an object ...
@pytest.fixture
def user_obj():
    return User(name="Alice", age=25)


# 3. ... or a factory function, so one test can build as many customised
#    objects as it needs.
@pytest.fixture
def user_factory():
    def create_user(name="Bob", age=20):
        return User(name=name, age=age)

    return create_user


def test_create_user(user_factory):
    user = user_factory(name="Charlie", age=30)
    assert user.name == "Charlie"
parametrize 参数化
parametrize 让你用不同参数运行同一测试。
基本用法
import pytest


# Each tuple becomes a separate test case (ids test_square[1-1], test_square[2-4], ...).
# ``value`` rather than ``input`` avoids shadowing the builtin.
@pytest.mark.parametrize("value,expected", [
    (1, 1),
    (2, 4),
    (3, 9),
    (4, 16),
])
def test_square(value, expected):
    assert value ** 2 == expected
多参数
import pytest


# Comma-separated argument names, matched positionally to each tuple.
@pytest.mark.parametrize("a,b,expected", [
    (1, 2, 3),
    (0, 0, 0),
    (-1, 1, 0),
    (100, 200, 300),
])
def test_add(a, b, expected):
    assert a + b == expected
组合参数
import pytest


# Stacked parametrize decorators produce the Cartesian product:
# 3 values of x × 2 values of y = 6 test cases.
@pytest.mark.parametrize("x", [1, 2, 3])
@pytest.mark.parametrize("y", [10, 20])
def test_combinations(x, y):
    # Assert a real property (commutativity); ``x + y == x + y`` could never fail.
    assert x + y == y + x
字符串参数
import pytest


# Parameters can be any objects; here, one string per test case.
# validate_email stands for the function under test.
@pytest.mark.parametrize("email", [
    "user@example.com",
    "test@test.cn",
    "admin@company.org",
])
def test_email_validation(email):
    assert validate_email(email) is True
标记失败测试
import pytest


# pytest.param attaches marks to a single case. ``expression`` avoids
# shadowing the builtin ``input``; evaluate stands for the code under test.
@pytest.mark.parametrize("expression,expected", [
    ("2 + 2", 4),
    ("3 * 3", 9),
    # Runs, but a failure is reported as XFAIL instead of FAILED.
    pytest.param("1 / 0", 0, marks=pytest.mark.xfail(reason="division by zero")),
    # Not run at all; reported as SKIPPED with this reason.
    pytest.param("0 / 0", 0, marks=pytest.mark.skip(reason="暂不支持")),
])
def test_evaluate(expression, expected):
    assert evaluate(expression) == expected
mock 与 patch
mock 用于隔离外部依赖,模拟不想实际调用的函数或对象。
unittest.mock 基本用法
from unittest.mock import Mock, patch, MagicMock


def test_mock():
    mock_obj = Mock()
    # Any attribute of a Mock is itself a Mock; configure what calling it returns.
    mock_obj.method.return_value = "mocked"

    result = mock_obj.method()
    assert result == "mocked"

    # Mocks record their calls, so the test can check how they were used.
    mock_obj.method.assert_called_once()
    mock_obj.method.assert_called_with()  # the last call had no arguments
patch 替换函数
from unittest.mock import Mock, patch


# patch replaces requests.get for the duration of the test and passes the
# replacement in as an extra argument. This only affects call_api if it looks
# the function up as ``requests.get`` (``import requests``). If its module does
# ``from requests import get``, patch "<that_module>.get" instead.
@patch('requests.get')
def test_api_call(mock_get):
    # The fake response: status_code attribute plus a json() method.
    mock_get.return_value = Mock(status_code=200, json=lambda: {"key": "value"})

    result = call_api()

    assert result == {"key": "value"}
    mock_get.assert_called_once_with("http://api.example.com")
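patch 的路径问题
patch 的目标应当是被测代码查找该名字的位置,而不是该对象被定义的位置。例如 module.py 中写了 `from other_module import func`,就应该 patch `module.func`,而不是 `other_module.func`。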
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
|
@patch('module.func') def test_correct(mock_func): mock_func.return_value = 42 result = func() assert result == 42
@patch('other_module.func') def test_wrong(): pass
|
spy 记录调用但不替换
from unittest.mock import patch

import module


# wraps=: the mock forwards each call to the real function and returns its
# result, while still recording the call for assertions.
@patch('module.func', wraps=module.func)
def test_spy(mock_func):
    # Call through the module attribute so the patched (spying) object is used.
    result = module.func(1, 2)

    assert result == 3  # assumes module.func adds its arguments (illustrative)
    mock_func.assert_called_once_with(1, 2)
Mock 常用配置
from unittest.mock import Mock

mock_obj = Mock()

# return_value: every call returns the same value.
mock_obj.method.return_value = 42

# A side_effect, once set, takes precedence over return_value, and assigning a
# new side_effect replaces the old one. So each variant below gets its own mock.
# An exception as side_effect is raised when the mock is called:
failing_method = Mock(side_effect=Exception("error"))
# A callable as side_effect receives the call's arguments; its result is returned:
doubling_method = Mock(side_effect=lambda x: x * 2)

# configure_mock sets attributes in bulk. It is also how to give a mock a
# ``name`` attribute, since Mock(name=...) is used for the mock's repr instead.
mock_obj.configure_mock(name="mocked_name")

# Different methods on one mock are configured independently.
mock_obj.other_method.return_value = "other"

# Chained calls: mock_obj.get_config().get_metrics() returns [1, 2, 3].
mock_obj.get_config.return_value.get_metrics.return_value = [1, 2, 3]
MagicMock 自动创建嵌套属性(普通 Mock 同样会按需自动创建子属性;MagicMock 额外预先配置了 `__len__`、`__iter__` 等魔术方法)
from unittest.mock import MagicMock

mock_obj = MagicMock()

# Attributes and call results are created on first access, so a deep chain can
# be configured without building each level by hand.
mock_obj.get_config.return_value.metrics = [1, 2, 3]  # mock_obj.get_config().metrics

# Unconfigured chains work too: this returns an auto-created MagicMock.
mock_obj.config.metrics.get()

# Plain Mock also auto-creates child attributes. What MagicMock adds is
# preconfigured magic methods, e.g. len(mock_obj) works and returns 0.
实战:重构后测试保障
项目结构
src/
├── user_service.py
├── user_repository.py
└── email_service.py
tests/
├── conftest.py
├── test_user_service.py
└── test_user_repository.py
conftest.py(共享 fixture)
from unittest.mock import MagicMock

import pytest


@pytest.fixture
def mock_user_repository():
    """Stand-in repository; each test configures the return values it needs.

    NOTE: a bare MagicMock accepts any attribute, including misspelled method
    names. unittest.mock.create_autospec(<real class>) makes such typos fail.
    """
    return MagicMock()


@pytest.fixture
def mock_email_service():
    """Stand-in email service whose send_welcome_email() returns True."""
    service = MagicMock()
    service.send_welcome_email.return_value = True
    return service


@pytest.fixture
def user_service(mock_user_repository, mock_email_service):
    """The real UserService wired to the two mocks above.

    Both mocks are function-scoped, so a test that also requests them receives
    the same instances that were passed to this service.
    """
    # Local import: only resolved when a test actually requests user_service.
    from src.user_service import UserService
    return UserService(mock_user_repository, mock_email_service)
测试业务逻辑
import pytest

from src.user_service import UserNotFoundError, UserValidationError


class TestUserService:
    """UserService tests; the repository and email service are conftest.py mocks."""

    def test_get_user_success(self, user_service, mock_user_repository):
        mock_user_repository.get_by_id.return_value = {
            "id": 1, "name": "Alice", "email": "alice@example.com"
        }

        result = user_service.get_user(1)

        assert result["name"] == "Alice"
        mock_user_repository.get_by_id.assert_called_once_with(1)

    def test_get_user_not_found(self, user_service, mock_user_repository):
        # Simulate a missing user: the repository returns None.
        mock_user_repository.get_by_id.return_value = None

        with pytest.raises(UserNotFoundError):
            user_service.get_user(999)

    def test_create_user_success(self, user_service, mock_user_repository,
                                 mock_email_service):
        mock_user_repository.create.return_value = {"id": 1, "name": "Bob"}

        result = user_service.create_user(name="Bob", email="bob@example.com")

        assert result["id"] == 1
        mock_user_repository.create.assert_called_once()
        # Request the email-service mock as a fixture (the same instance that
        # user_service received); the test class has no attribute holding it.
        mock_email_service.send_welcome_email.assert_called_once()

    def test_create_user_invalid_email(self, user_service):
        # match= is a regular expression searched for in the exception message.
        with pytest.raises(UserValidationError, match="Invalid email"):
            user_service.create_user(name="Bob", email="invalid-email")
参数化 + fixture
import pytest


# Parametrized arguments and fixtures mix freely: pytest fills ``email`` and
# ``expected`` from the table and ``user_service`` from conftest.py. The fixture
# is function-scoped, so every case gets a fresh service.
@pytest.mark.parametrize("email,expected", [
    ("valid@example.com", True),
    ("another@test.cn", True),
    ("no-at-sign.com", False),
    ("missing@domain", False),
    ("", False),
])
def test_email_validation(user_service, email, expected):
    assert user_service.validate_email(email) == expected
覆盖率报告
生成覆盖率报告
# pytest-cov provides the --cov options
pip install pytest-cov

# HTML report for the src package, written to htmlcov/
pytest --cov=src --cov-report=html tests/

# Open the report (`open` is macOS; use xdg-open on Linux)
open htmlcov/index.html

# Terminal report that also lists the uncovered line numbers
pytest --cov=src --cov-report=term-missing tests/
覆盖率输出示例
---------- coverage: platform darwin, Python 3.11.0 ----------
Name                   Stmts   Miss  Cover   Missing
----------------------------------------------------
src/user_service.py       50      5    90%   32-34, 45-46
src/email_service.py      20     20     0%   1-20
----------------------------------------------------
TOTAL                     70     25    64%
排除代码
# .coveragerc  (in setup.cfg / tox.ini the sections are [coverage:run] / [coverage:report])
[run]
source = src
omit =
    */tests/*
    */migrations/*
    */__init__.py

[report]
# Setting exclude_lines replaces coverage's default list, so the
# "pragma: no cover" pattern has to be listed again explicitly.
# Entries are regexes: in the last one, each "." matches either quote character.
exclude_lines =
    pragma: no cover
    def __repr__
    raise NotImplementedError
    if __name__ == .__main__.:
总结
pytest 核心功能:
| 功能 | 用途 | 示例 |
| --- | --- | --- |
| fixture | 提供测试数据和 setup/teardown | `@pytest.fixture` |
| parametrize | 用不同参数运行同一测试 | `@pytest.mark.parametrize` |
| mock | 隔离外部依赖 | `unittest.mock.Mock` |
| patch | 替换函数/对象 | `@patch` |
测试分层建议:
- 单元测试:测试单个函数/类的逻辑
- 集成测试:测试组件间的交互
- 端到端测试:测试完整的业务流程
好的测试习惯:
- 每个测试函数只测一个功能点
- 测试之间相互独立
- 给测试起有意义的名字
- 及时更新测试以反映代码变化