目录
- 1. 背景
- 2. dataclass函数签名详解
- 2.1 repr示例
- 2.2 eq与order示例
- 2.3 frozen示例
- 2.4 `__post_init__`
- 2.5 继承
- 3. Field
- 3.1 default与default_factory
- 3.2 init与repr
- 3.3 compare
- 3.4 metadata
- Ref
1. 背景
考虑这样一个场景。假如我们要定义一个 Person
类,并希望它具有姓名、性别、年龄、身高、体重这五个属性,则可以这样写:
class Person:def __init__(self, name: str, age: int, sex: str, height: float, weight: float):self.name = nameself.age = ageself.sex = sexself.height = heightself.weight = weight
如果我们还希望 print(person)
的时候能够打印出这个人的所有信息,并且支持比较两个人是否相等,则应当添加 __repr__
和 __eq__
方法:
class Person:def __init__(self, name: str, age: int, sex: str, height: float, weight: float):self.name = nameself.age = ageself.sex = sexself.height = heightself.weight = weightdef __repr__(self):return (f"Person(name={self.name!r}, age={self.age}, sex={self.sex!r}, "f"height={self.height}, weight={self.weight})")def __eq__(self, other):if isinstance(other, Person):return (self.name == other.name and self.age == other.age andself.sex == other.sex and self.height == other.height andself.weight == other.weight)return False
可以看出代码有些许复杂,如果使用 dataclass
,代码将得到大大简化:
from dataclasses import dataclass@dataclass
class Person:name: strage: intsex: strheight: floatweight: float
在使用了 @dataclass
装饰器后,我们只需要声明每个属性及对应的类型,dataclass
会自动为该类生成 __init__
、__repr__
和 __eq__
方法,生成的 __eq__
方法会比较所有的属性。
有了以上背景知识后,我们再来看一下 dataclasses
这个库到底是干什么的:
📝
dataclasses
是 Python 3.7 及更高版本中引入的一个标准库,用于简化类的定义。它旨在通过使用装饰器和类型注解来减少样板代码,特别是在创建一个用于存储数据的类的情况下。dataclasses
提供了一个装饰器和一系列支持函数,让你能够快速定义保存数据的类,同时自动添加特殊方法,比如__init__()
、__repr__()
和__eq__()
等。
2. dataclass函数签名详解
dataclass
的函数签名如下:
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,unsafe_hash=False, frozen=False, match_args=True,kw_only=False, slots=False, weakref_slot=False):
⚠️
/
之前的所有参数,都必须以位置形式传参,不得以关键字形式传参。*
之后的所有参数,都必须以关键字形式传参,不得以位置形式传参。/
和*
之间的所有参数,既可以以位置形式传参,也可以以关键字形式传参。该特性由Python 3.8版本引入。
cls
是将要被装饰的类。
init
:决定是否为类添加__init__
方法。repr
:决定是否为类添加__repr__
方法。eq
:决定是否为类添加__eq__
方法。order
:决定是否为类添加__lt__
、__le__
、__gt__
、__ge__
方法。frozen
:决定是否冻结实例化后的属性。一旦冻结,任何企图修改对象属性的行为都会引发FrozenInstanceError
。
⚠️
dataclass
为cls
添加方法时实际上是通过setattr
函数来实现(例如setattr(cls, "__repr__", _repr_fn)
)。在setattr(object, name, value)
中,object
可以是类对象,实例对象,模块对象,或是实现了__setattr__
方法的对象。name
只能是字符串。value
可以是基本数据类型,容器类型,类的实例,函数等。
2.1 repr示例
@dataclass
class Person:name: strage: intsex: strheight: floatweight: floatperson = Person(name="John Doe", age=30, sex="Male", height=1.75, weight=70)print(person)
# Person(name='John Doe', age=30, sex='Male', height=1.75, weight=70)
可以看到 person
的所有信息都被打印出来了,并且打印的格式和实例化的格式几乎相同。
⚠️
name: str
仅仅是一个类型注解,我们完全可以在实例化的时候向name
传入一个整数。
我们还可以预设默认值:
@dataclass
class Person:name: str = "John Doe"age: int = 30sex: str = "Male"height: float = 1.75weight: float = 70person = Person()print(person)
# 结果同上
注意,类型注解不是必须的,但此时必须要给相应的变量赋值:
@dataclass
class Person:name = "John Doe"age: intsex: strheight = 1.75weight # 这一行会报错
2.2 eq与order示例
注意到 eq
是 order
的一个简化情形,这里我们只讨论 order
。dataclass
是如何比较两个类的大小呢?
事实上,在比较类的大小时,dataclass
会将类的所有属性按照其声明顺序打包成一个元组,然后比较两个元组的大小。以上述的 Person
为例,当设置 order=True
时,其生成的 __le__
方法等价于:
def __le__(self, other):return ((self.name, self.age, self.sex, self.height, self.weight) <=(other.name, other.age, other.sex, other.height, other.weight))
由于 name
是字符串,则会按照字典序进行比较,如果 self.name <= other.name
,则 self <= other
。如果 self.name == other.name
,那么就会比较下一项 age
,以此类推。
一些比较的例子:
@dataclass(order=True)
class Person:name: strage: intsex: strheight: floatweight: floatp1 = Person("Alice", 30, "Female", 1.65, 60)
p2 = Person("Bob", 25, "Male", 1.75, 80)
p3 = Person("Alice", 35, "Female", 1.65, 60)
p4 = Person("Alice", 30, "Female", 1.70, 65)
p5 = Person("Charlie", 30, "Male", 1.80, 90)
p6 = Person("Bob", 25, "Male", 1.75, 75)
p7 = Person("Alice", 30, "Male", 1.65, 60)
p8 = Person("Alice", 30, "Female", 1.65, 55)comparisons = [(p1, p2),(p1, p3),(p1, p4),(p1, p5),(p2, p6),(p1, p7),(p1, p8),
]results = [(p1 <= p2) for p1, p2 in comparisons]
print(results)
# [True, True, True, True, False, True, False]
2.3 frozen示例
在 dataclass
中设置 frozen=True
会使得生成的类实例是不可变的。这意味着一旦一个实例被创建,它的任何属性都不能被改变。这类似于一个只读对象,可以确保实例在创建后保持不变,有助于确保数据的一致性和线程安全性。
from dataclasses import dataclass@dataclass(frozen=True)
class Person:name: strage: intsex: strheight: floatweight: floatimmutable_person = Person("Jane", 28, "Female", 1.68, 58)
immutable_person.age = 29
# dataclasses.FrozenInstanceError: cannot assign to field 'age'
2.4 __post_init__
__post_init__
方法是 dataclasses
中一个非常有用的特性,允许你在一个类的初始化后立即运行一些额外的代码。当你使用 @dataclass
装饰器来装饰一个类时,这个类会自动获得一个 __init__
方法,它会根据类中定义的属性来初始化实例。然而,如果你需要在实例化后立即执行一些操作(比如属性验证、转换、或基于其他属性计算一个属性的值),__post_init__
方法就显得非常有用。
📝
__post_init__
是dataclasses
模块特有的一个方法,普通的类没有这个方法。
考虑这样一个场景,如果一个类只有两个属性 a
和 b
(均为 int
型),a
由人工输入决定,b
的值是 a
的值的两倍,如果使用普通类,我们可以这样写:
class MyClass:def __init__(self, a: int):self.a = aself.b = a * 2
如果使用 dataclass
,我们的第一反应可能是下面这样:
@dataclass
class MyClass:a: intb: int = a * 2
但这样会出现报错:NameError: name 'a' is not defined
,因为这种做法实际上相当于:
class MyClass:def __init__(self, a: int, b: int = a * 2):self.a = aself.b = b
我们可以添加 __post_init__
方法,它会在 __init__
之后立刻执行:
@dataclass
class MyClass:a: intdef __post_init__(self):self.b = self.a * 2myclass = MyClass(a=2)
print(myclass.a)
print(myclass.b) # 4
除了计算属性之外, __post_init__
还可以用来校验属性,如下是一个较为复杂的例子:
from dataclasses import dataclass, field
import redef validate_email(email):pattern = r"^\w+([\.-]?\w+)*@\w+([\.-]?\w+)*(\.\w{2,3})+$"if not re.match(pattern, email):raise ValueError(f"无效的电子邮件地址: {email}")def validate_phone(phone):pattern = r"^\+?[0-9]{10,15}$"if not re.match(pattern, phone):raise ValueError(f"无效的电话号码: {phone}")@dataclass
class Employee:name: strage: intemail: strphone: strdepartment: strsalary: floatdef __post_init__(self):if len(self.name) < 3:raise ValueError("姓名长度必须至少为3个字符")if not (18 <= self.age <= 65):raise ValueError("年龄必须在18到65之间")validate_email(self.email)validate_phone(self.phone)if len(self.department) == 0:raise ValueError("部门不能为空")if self.salary < 0:raise ValueError("薪资不能为负数")
2.5 继承
dataclass
同样支持继承,子类将拥有父类的所有属性和方法:
@dataclass
class Person:name: strage: intdef greet(self):return f"Hello, my name is {self.name} and I am {self.age} years old."@dataclass
class Employee(Person):salary: floatdef show_salary(self):return f"My salary is {self.salary}."emp = Employee(name="John Doe", age=30, salary=50000)print(emp.greet())
print(emp.show_salary())
子类还可以重写父类的属性和方法:
@dataclass
class Person:name: str = 'None'age: int = -1def introduce(self):return f"Hello, my name is {self.name} and I am {self.age} years old."@dataclass
class Employee(Person):salary: float = 50000name: str = 'John Doe'age: int = 30def introduce(self):base_introduction = super().introduce()return f"{base_introduction}\nMy salary is {self.salary}."emp = Employee()
print(emp.introduce())
3. Field
📝 被
@dataclass
装饰的类的属性又叫做字段(field)。
注意,以上的解释并不意味着一个被 @dataclass
装饰的类的属性就是 Field
的实例:
@dataclass
class MyClass:a: int = 2myclass = MyClass()
print(type(myclass.a))
# <class 'int'>
引入 Field
的目的是为了更灵活、更精细化地控制数据类的字段,Field
的构造函数如下:
class Field:def __init__(self, default, default_factory, init, repr, hash, compare, metadata):
3.1 default与default_factory
⚠️
default
和default_factory
不能同时指定。
default
参数用来指定字段的默认值。如果字段未在实例化时提供值,则会使用此默认值。
from dataclasses import dataclass, field@dataclass
class A:a: int = field(default=2)A1 = A()
A2 = A(3)
print(A1.a, A2.a)
# 2 3
如果要设置某一个字段的默认值为可变类型(例如列表、字典等),那么所有实例将共享这同一个默认值对象(参考博客)。这可能会导致意想不到的行为,因为修改任何一个实例的字段将影响所有实例。例如:
@dataclass
class A:a: List[int] = []
将会触发报错
ValueError: mutable default <class 'list'> for field a is not allowed: use default_factory
使用 default_factory
时,我们需要为它赋值一个无参数的可调用对象。每次创建数据类的实例时,都会调用这个可调用对象来生成该字段的默认值。面对以上报错,我们可以使用 default_factory
来解决:
@dataclass
class A:a: List[int] = field(default_factory=list) # 默认值是空列表b: List[int] = field(default_factory=lambda: [1, 2, 3]) # 默认值是[1, 2, 3]
3.2 init与repr
init
参数用于指定一个字段是否应该包含在自动生成的 __init__
方法中,请看下面的例子:
@dataclass
class A:a: intb: int = field(init=False)print(inspect.signature(A.__init__))
# (self, a: int) -> None
通过函数签名可以看出,A
自动生成的构造函数里,并没有 b
这个形参,如果我们执行 a = A(1, 2)
这样的实例化,则会报错:
TypeError: __init__() takes 2 positional arguments but 3 were given
设置 init=True
后的函数签名变成:(self, a: int, b: int) -> None
。
使用场景: 如果某个字段是基于其他字段的值计算得出的,我们可以将该字段的 init
参数设置为 False
(非必须,参考2.4节,但为了可读性最好还是声明一下),然后在 __post_init__
中计算。
repr
参数控制一个字段是否应该包含在自动生成的 __repr__
方法的返回值中,请看下面的例子:
@dataclass
class A:a: intb: int@dataclass
class B:a: intb: int = field(repr=False)a = A(1, 2)
b = B(1, 2)print(a) # A(a=1, b=2)
print(b) # B(a=1)
使用场景: 如果类中某些字段是辅助性质的,或者可能包含大量数据,不适合在每次打印对象时都显示,则可以通过将这些字段的 repr
参数设置为 False
来排除它们。
3.3 compare
compare
用来决定是否在比较方法中(如 __eq__
, __lt__
, __le__
, __gt__
, 和 __ge__
)包含该字段。
from dataclasses import dataclass, field@dataclass
class Book:title: strauthor: stryear: int = field(compare=False)book1 = Book("The Great Gatsby", "F. Scott Fitzgerald", 1925)
book2 = Book("The Great Gatsby", "F. Scott Fitzgerald", 2020)print(book1 == book2) # True
3.4 metadata
metadata
参数用来为字段附加额外信息,这些信息不影响字段的行为,但可以被你的程序或第三方库用于各种目的,例如验证、序列化或其他自定义处理。
metadata
通常是一个字典:
@dataclass
class Product:name: strprice: float = field(default=0.0, metadata={'unit': 'USD'})in_stock: bool = field(default=True, metadata={'description': 'Whether the product is in stock'})
HuggingFace的 TrainingArguments
实际上就是典型的数据类,如下展示了它的前几个字段:
@dataclass
class TrainingArguments:framework = "pt"output_dir: str = field(metadata={"help": "The output directory where the model predictions and checkpoints will be written."},)overwrite_output_dir: bool = field(default=False,metadata={"help": ("Overwrite the content of the output directory. ""Use this to continue training if output_dir points to a checkpoint directory.")},)do_train: bool = field(default=False, metadata={"help": "Whether to run training."})do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
Ref
[1] https://zhuanlan.zhihu.com/p/59657729
[2] https://blog.csdn.net/be5yond/article/details/119545119
[3] https://zhuanlan.zhihu.com/p/61553610