Python 装饰器高级用法:类装饰器、带参数的装饰器、装饰器链

在之前的博客中,我整理了 Python 装饰器的基础知识。这篇文章进一步探讨装饰器的高级用法,包括类装饰器、带参数的装饰器以及装饰器链的实现。

函数装饰器回顾

先简单回顾函数装饰器的基本形态:

1
2
3
4
5
6
7
8
9
10
11
def my_decorator(func):
def wrapper(*args, **kwargs):
print("调用前")
result = func(*args, **kwargs)
print("调用后")
return result
return wrapper

@my_decorator
def say_hello(name):
print(f"Hello, {name}")

这个模式很清晰:my_decorator 接收一个函数,返回一个包装后的函数。

但实际开发中,我们经常遇到更复杂的需求。

类装饰器

装饰器不一定非得是函数,也可以是实现了 __call__ 方法的类

基本用法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class CountCalls:
"""统计函数调用次数的装饰器"""

def __init__(self, func):
self.func = func
self.count = 0

def __call__(self, *args, **kwargs):
self.count += 1
print(f"函数被调用了 {self.count} 次")
return self.func(*args, **kwargs)

@CountCalls
def say_hello(name):
print(f"Hello, {name}")

say_hello("Alice") # 函数被调用了 1 次\nHello, Alice
say_hello("Bob") # 函数被调用了 2 次\nHello, Bob

@CountCalls 装饰在函数上时,CountCalls__init__ 会被调用,传入被装饰的函数。当函数被调用时,实际上是调用了 CountCalls__call__ 方法。

保存函数元信息

类装饰器配合 functools.wraps 使用需要注意:

1
2
3
4
5
6
7
8
9
10
import functools

class Logging:
def __init__(self, func):
self.func = func
functools.update_wrapper(self, func) # 复制元信息到类实例

def __call__(self, *args, **kwargs):
print(f"调用 {self.func.__name__}")
return self.func(*args, **kwargs)

注意这里需要手动调用 update_wrapper,因为 wraps 是为嵌套函数设计的,对类的支持需要额外处理。

带参数的装饰器

很多时候,我们需要给装饰器传递参数,比如指定日志级别、缓存超时时间等。

装饰器工厂函数

解决方案是再包装一层

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def repeat(times=1):
"""重复执行函数指定次数"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
results = []
for _ in range(times):
results.append(func(*args, **kwargs))
return results
return wrapper
return decorator

@repeat(times=3)
def say_hello():
print("Hello")
return "done"

say_hello() # 打印三次 "Hello",返回 ["done", "done", "done"]

执行流程:

  1. repeat(times=3) 调用,返回 decorator 函数
  2. decorator(say_hello) 调用,返回 wrapper 函数
  3. say_hello() 实际调用 wrapper()

带参数的类装饰器

同样的思路用于类装饰器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Retry:
"""失败时重试指定次数"""

def __init__(self, times=3):
self.times = times

def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
for i in range(self.times):
try:
return func(*args, **kwargs)
except Exception as e:
if i == self.times - 1:
raise
print(f"第 {i+1} 次失败,重试中...")
return None
return wrapper

@Retry(times=3)
def fetch_data():
# 可能失败的逻辑
pass

装饰器链的执行顺序

一个函数可以同时被多个装饰器修饰,执行顺序是从近到远,从上到下

1
2
3
4
5
@a
@b
@c
def func():
pass

等价于:

1
func = a(b(c(func)))

执行顺序:

  1. 先执行 @c):用 c装饰func`
  2. 再执行 @b:用 b 装饰结果
  3. 最后执行 @a:用 a 装饰结果

调用 func() 时,执行顺序是从外到内。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import functools

def a(func):
print("a 开始")
@functools.wraps(func)
def wrapper(*args, **kwargs):
print("a 执行")
return func(*args, **kwargs)
print("a 结束")
return wrapper

def b(func):
print("b 开始")
@functools.wraps(func)
def wrapper(*args, **kwargs):
print("b 执行")
return func(*args, **kwargs)
print("b 结束")
return wrapper

@a
@b
def test():
print("test 执行")

# 输出顺序:
# b 开始
# b 结束
# a 开始
# a 结束
# 调用 test() 时:a 执行 -> b 执行 -> test 执行

functools.wraps 的作用

functools.wraps 是个容易被忽视但很重要的工具:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import functools

def my_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper

@my_decorator
def original():
"""这是 original 的文档"""
pass

print(original.__name__) # 'original',而不是 'wrapper'
print(original.__doc__) # '这是 original 的文档'

wraps 将原函数的 __name____doc____module____annotations__ 等属性复制到包装函数。如果没有这步,装饰后的函数会丢失所有元信息,对调试和文档生成影响很大。

实战:日志装饰器

结合以上知识,实现一个功能完整的日志装饰器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import functools
import time
import logging

logger = logging.getLogger(__name__)

def log(level='INFO', include_args=True, include_result=False):
"""灵活的日志装饰器

Args:
level: 日志级别
include_args: 是否记录参数
include_result: 是否记录返回值
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
log_func = getattr(logger, level.lower())
start_time = time.time()

# 记录调用信息
msg = f"调用 {func.__name__}"
if include_args:
sig = str(args) + str(kwargs)
msg += f",参数: {sig}"
log_func(msg)

try:
result = func(*args, **kwargs)

# 记录执行结果
elapsed = time.time() - start_time
success_msg = f"{func.__name__} 执行成功,耗时 {elapsed:.3f}s"
if include_result:
success_msg += f",返回值: {result}"
log_func(success_msg)

return result
except Exception as e:
elapsed = time.time() - start_time
logger.error(f"{func.__name__} 执行失败,耗时 {elapsed:.3f}s,异常: {e}")
raise

return wrapper
return decorator

# 使用示例
@log(level='DEBUG', include_args=True, include_result=False)
def calculate(a, b):
return a + b

@log(level='INFO')
def process_data(data):
# 业务逻辑
pass

实战:性能统计装饰器

另一个实用场景是性能统计:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import functools
import time
from collections import defaultdict

class PerformanceStats:
"""函数性能统计(类装饰器,支持多函数统计)"""

_stats = defaultdict(lambda: {'count': 0, 'total_time': 0, 'errors': 0})

def __init__(self, func):
self.func = func
functools.update_wrapper(self, func)

def __call__(self, *args, **kwargs):
name = self.func.__name__
start = time.perf_counter()

try:
result = self.func(*args, **kwargs)
elapsed = time.perf_counter() - start

self._stats[name]['count'] += 1
self._stats[name]['total_time'] += elapsed
return result
except Exception:
self._stats[name]['errors'] += 1
raise

@classmethod
def report(cls):
"""输出统计报告"""
print("\n=== 性能统计报告 ===")
print(f"{'函数名':<20} {'调用次数':<10} {'总耗时(s)':<12} {'平均耗时(ms)':<15} {'错误数':<8}")
print("-" * 70)
for name, stats in cls._stats.items():
avg = (stats['total_time'] / stats['count'] * 1000) if stats['count'] > 0 else 0
print(f"{name:<20} {stats['count']:<10} {stats['total_time']:<12.4f} {avg:<15.2f} {stats['errors']:<8}")

@PerformanceStats
def slow_operation():
time.sleep(0.1)

@PerformanceStats
def fast_operation():
pass

# 测试
for _ in range(5):
slow_operation()

for _ in range(10):
fast_operation()

PerformanceStats.report()

输出:

1
2
3
4
5
=== 性能统计报告 ===
函数名 调用次数 总耗时(s) 平均耗时(ms) 错误数
----------------------------------------------------------------------
slow_operation 5 0.5005 100.10 0
fast_operation 10 0.0000 0.02 0

总结

装饰器高级用法的核心要点:

  1. 类装饰器:实现 __call__ 方法,保存状态更自然
  2. 带参数的装饰器:返回装饰器的工厂函数
  3. 装饰器链:从近到远执行,最终效果从外到内
  4. **functools.wraps**:保留原函数的元信息
  5. 实际应用:日志、性能统计、重试、缓存、权限校验等

装饰器是 Python 最强大的特性之一,掌握这些高级用法能让代码更加优雅和可复用。