Python 单元测试 pytest 实战:fixture、parametrize、mock

单元测试是保证代码质量的重要手段。pytest 是 Python 最流行的测试框架,简洁强大。本文介绍 pytest 的核心功能:fixture、parametrize 和 mock,以及如何写出可维护的测试代码。

pytest 基础用法

安装与运行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
pip install pytest pytest-cov

# 运行所有测试
pytest

# 运行指定文件
pytest tests/test_user.py

# 运行指定测试函数
pytest tests/test_user.py::test_get_user

# 查看详细输出
pytest -v

# 停在第一个失败
pytest -x

断言

pytest 内置断言,失败时会显示详细信息:

1
2
3
4
5
6
7
8
9
10
11
def test_basic():
assert 1 + 1 == 2
assert "hello".upper() == "HELLO"
assert [1, 2, 3] == [1, 2, 3]

# 断言异常
with pytest.raises(ValueError):
int("not a number")

# 浮点数比较
assert 0.1 + 0.2 == pytest.approx(0.3)

fixture 作用域与共享

fixture 是 pytest 最强大的功能之一,用于提供测试数据、setup/teardown。

基本用法

1
2
3
4
5
6
7
8
9
10
import pytest

# fixture 定义
@pytest.fixture
def user():
return {"name": "Alice", "age": 25}

# 使用 fixture
def test_user_name(user):
assert user["name"] == "Alice"

scope 作用域

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# function:每个函数执行一次(默认)
@pytest.fixture(scope="function")
def db_connection():
conn = create_db_connection()
yield conn
conn.close() # teardown

# class:每个类执行一次
@pytest.fixture(scope="class")
def db_connection():
conn = create_db_connection()
yield conn
conn.close()

# module:每个模块执行一次
@pytest.fixture(scope="module")
def config():
return load_config()

# session:整个测试会话执行一次
@pytest.fixture(scope="session")
def browser():
driver = launch_browser()
yield driver
driver.quit()

fixture 依赖

1
2
3
4
5
6
7
8
9
10
11
12
@pytest.fixture
def db_connection():
return create_connection()

@pytest.fixture
def user_repository(db_connection): # 依赖 db_connection
return UserRepository(db_connection)

def test_save_user(user_repository):
user = User(name="Alice")
user_repository.save(user)
assert user.id is not None

autouse 自动执行

1
2
3
4
5
6
@pytest.fixture(autouse=True)
def setup_teardown():
# 每个测试函数都会自动执行
print("\nsetup")
yield
print("\nteardown")

fixture 返回字典/对象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 方式1:返回字典
@pytest.fixture
def user_dict():
return {"name": "Alice", "age": 25}

# 方式2:使用 dataclass
from dataclasses import dataclass

@dataclass
class User:
name: str
age: int

@pytest.fixture
def user_obj():
return User(name="Alice", age=25)

# 方式3:使用工厂函数
@pytest.fixture
def user_factory():
def create_user(name="Bob", age=20):
return User(name=name, age=age)
return create_user

def test_create_user(user_factory):
user = user_factory(name="Charlie", age=30)
assert user.name == "Charlie"

parametrize 参数化

parametrize 让你用不同参数运行同一测试。

基本用法

1
2
3
4
5
6
7
8
@pytest.mark.parametrize("input,expected", [
(1, 1),
(2, 4),
(3, 9),
(4, 16),
])
def test_square(input, expected):
assert input ** 2 == expected

多参数

1
2
3
4
5
6
7
8
@pytest.mark.parametrize("a,b,expected", [
(1, 2, 3),
(0, 0, 0),
(-1, 1, 0),
(100, 200, 300),
])
def test_add(a, b, expected):
assert a + b == expected

组合参数

1
2
3
4
5
@pytest.mark.parametrize("x", [1, 2, 3])
@pytest.mark.parametrize("y", [10, 20])
def test_combinations(x, y):
# 会运行 3 * 2 = 6 次
assert x + y == x + y

字符串参数

1
2
3
4
5
6
7
@pytest.mark.parametrize("email", [
"user@example.com",
"test@test.cn",
"admin@company.org",
])
def test_email_validation(email):
assert validate_email(email) is True

标记失败测试

1
2
3
4
5
6
7
8
@pytest.mark.parametrize("input,expected", [
("2 + 2", 4),
("3 * 3", 9),
pytest.param("1 / 0", 0, marks=pytest.mark.xfail), # 预期失败
pytest.param("0 / 0", 0, marks=pytest.mark.skip("暂不支持")),
])
def test_evaluate(input, expected):
assert evaluate(input) == expected

mock 与 patch

mock 用于隔离外部依赖,模拟不想实际调用的函数或对象。

unittest.mock 基本用法

1
2
3
4
5
6
7
8
9
10
11
12
13
from unittest.mock import Mock, patch, MagicMock

# Mock 对象
def test_mock():
mock_obj = Mock()
mock_obj.method.return_value = "mocked"

result = mock_obj.method()
assert result == "mocked"

# 验证调用
mock_obj.method.assert_called_once()
mock_obj.method.assert_called_with()

patch 替换函数

1
2
3
4
5
6
7
8
9
# 替换一个函数
@patch('requests.get')
def test_api_call(mock_get):
mock_get.return_value = Mock(status_code=200, json=lambda: {"key": "value"})

result = call_api()

assert result == {"key": "value"}
mock_get.assert_called_once_with("http://api.example.com")

patch 的路径问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 假设代码是:
# from module import func
# func()

# mock 时要 mock 函数被导入的地方
@patch('module.func') # 正确
def test_correct(mock_func):
mock_func.return_value = 42
result = func()
assert result == 42

# 错误写法
@patch('other_module.func') # 错误
def test_wrong():
pass

spy 记录调用但不替换

1
2
3
4
5
6
7
8
9
from unittest.mock import call

@patch('module.func', wraps=module.func)
def test_spy(mock_func):
# func 仍然执行,但调用被记录
result = func(1, 2)

assert result == 3
mock_func.assert_called_once_with(1, 2)

Mock 常用配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 配置返回值
mock_obj.method.return_value = 42
mock_obj.method.side_effect = Exception("error") # 抛出异常
mock_obj.method.side_effect = lambda x: x * 2 # 函数式

# 配置属性
mock_obj.configure_mock(name="mocked_name")

# 部分 mock
mock_obj.method.return_value = 42
mock_obj.other_method.return_value = "other"

# 链式调用
mock_obj.get_config.return_value.get_metrics.return_value = [1, 2, 3]

MagicMock 自动创建嵌套

1
2
3
4
5
6
7
# MagicMock 会自动创建不存在的属性和方法
mock_obj = MagicMock()

mock_obj.get_config.return_value.metrics = [1, 2, 3]

# 不需要手动配置
mock_obj.config.metrics.get() # MagicMock

实战:重构后测试保障

项目结构

1
2
3
4
5
6
7
8
src/
├── user_service.py
├── user_repository.py
└── email_service.py
tests/
├── conftest.py
├── test_user_service.py
└── test_user_repository.py

conftest.py(共享 fixture)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import pytest
from unittest.mock import MagicMock

@pytest.fixture
def mock_user_repository():
return MagicMock()

@pytest.fixture
def mock_email_service():
service = MagicMock()
service.send_welcome_email.return_value = True
return service

@pytest.fixture
def user_service(mock_user_repository, mock_email_service):
from src.user_service import UserService
return UserService(mock_user_repository, mock_email_service)

测试业务逻辑

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from unittest.mock import Mock, patch
import pytest
from src.user_service import UserNotFoundError, UserValidationError

class TestUserService:

def test_get_user_success(self, user_service, mock_user_repository):
# Arrange
mock_user_repository.get_by_id.return_value = {
"id": 1, "name": "Alice", "email": "alice@example.com"
}

# Act
result = user_service.get_user(1)

# Assert
assert result["name"] == "Alice"
mock_user_repository.get_by_id.assert_called_once_with(1)

def test_get_user_not_found(self, user_service, mock_user_repository):
mock_user_repository.get_by_id.return_value = None

with pytest.raises(UserNotFoundError):
user_service.get_user(999)

def test_create_user_success(self, user_service, mock_user_repository):
mock_user_repository.create.return_value = {"id": 1, "name": "Bob"}

result = user_service.create_user(name="Bob", email="bob@example.com")

assert result["id"] == 1
mock_user_repository.create.assert_called_once()
# 验证发送邮件
self.user_service_mock.email_service.send_welcome_email.assert_called_once()

def test_create_user_invalid_email(self, user_service):
with pytest.raises(UserValidationError, match="Invalid email"):
user_service.create_user(name="Bob", email="invalid-email")

参数化 + fixture

1
2
3
4
5
6
7
8
9
@pytest.mark.parametrize("email,expected", [
("valid@example.com", True),
("another@test.cn", True),
("no-at-sign.com", False),
("missing@domain", False),
("", False),
])
def test_email_validation(user_service, email, expected):
assert user_service.validate_email(email) == expected

覆盖率报告

生成覆盖率报告

1
2
3
4
5
6
7
8
9
10
11
# 安装
pip install pytest-cov

# 运行并生成报告
pytest --cov=src --cov-report=html tests/

# 查看 HTML 报告
open htmlcov/index.html

# 覆盖率统计
pytest --cov=src --cov-report=term-missing tests/

覆盖率输出示例

1
2
3
4
5
6
7
---------- coverage: platform darwin, Python 3.11.0 ----------
Name Stmts Miss Cover Missing
------------------------------------------------------------
src/user_service.py 50 5 90% 32,45
src/email_service.py 20 20 0% all
------------------------------------------------------------
TOTAL 70 25 64%

排除代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# .coveragerc
[run]
source = src
omit =
*/tests/*
*/migrations/*
*/__init__.py

[report]
exclude_lines =
pragma: no cover
def __repr__
raise NotImplementedError
if __name__ == .__main__.:

总结

pytest 核心功能:

功能 用途 示例
fixture 提供测试数据和 setup/teardown @pytest.fixture
parametrize 用不同参数运行同一测试 @pytest.mark.parametrize
mock 隔离外部依赖 unittest.mock.Mock
patch 替换函数/对象 @patch

测试分层建议:

  1. 单元测试:测试单个函数/类的逻辑
  2. 集成测试:测试组件间的交互
  3. 端到端测试:测试完整的业务流程

好的测试习惯:

  • 每个测试函数只测一个功能点
  • 测试之间相互独立
  • 给测试起有意义的名字
  • 及时更新测试以反映代码变化