Classes and Interfaces
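As an object-oriented language, Python fully supports inheritance, polymorphism, and encapsulation. How do you use these features to write maintainable code?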
- Item37: Compose classes instead of nesting many levels of built-in types
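Suppose I want to record the grades of a set of students whose names aren't known in advance. I can define a class that stores the grades in a dictionary keyed by student name: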
class SimpleGradebook:
    def __init__(self):
        self._grades = {}
    def add_student(self, name):
        self._grades[name] = []
    def report_grade(self, name, score):
        self._grades[name].append(score)
    def average_grade(self, name):
        grades = self._grades[name]
        return sum(grades) / len(grades)
book = SimpleGradebook()
book.add_student('Isaac Newton')
book.report_grade('Isaac Newton', 90)
book.report_grade('Isaac Newton', 95)
book.report_grade('Isaac Newton', 85)
print(book.average_grade('Isaac Newton'))
>>>
90.0
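Dictionaries and their related built-in types are so easy to use that there's a danger of overextending them. For example, suppose I now want to keep the grades by subject, not just overall: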
from collections import defaultdict
class BySubjectGradebook:
    def __init__(self):
        self._grades = {}  # Outer dict
    def add_student(self, name):
        self._grades[name] = defaultdict(list)  # Inner dict
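This seems straightforward enough, and the two-level dictionary still looks manageable. The remaining methods grow more complex to handle the nesting: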
    def report_grade(self, name, subject, grade):
        by_subject = self._grades[name]
        grade_list = by_subject[subject]
        grade_list.append(grade)
    def average_grade(self, name):
        by_subject = self._grades[name]
        total, count = 0, 0
        for grades in by_subject.values():
            total += sum(grades)
            count += len(grades)
        return total / count
book = BySubjectGradebook()
book.add_student('Albert Einstein')
book.report_grade('Albert Einstein', 'Math', 75)
book.report_grade('Albert Einstein', 'Math', 65)
book.report_grade('Albert Einstein', 'Gym', 90)
book.report_grade('Albert Einstein', 'Gym', 95)
print(book.average_grade('Albert Einstein'))
>>>
81.25
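Now suppose the requirements change again: different tests carry different weights, so each entry must store a weight in addition to the score: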
class WeightedGradebook:
    def __init__(self):
        self._grades = {}
    def add_student(self, name):
        self._grades[name] = defaultdict(list)
    def report_grade(self, name, subject, score, weight):
        by_subject = self._grades[name]
        grade_list = by_subject[subject]
        grade_list.append((score, weight))
    def average_grade(self, name):
        by_subject = self._grades[name]
        score_sum, score_count = 0, 0
        for subject, scores in by_subject.items():
            subject_avg, total_weight = 0, 0
            for score, weight in scores:
                subject_avg += score * weight
                total_weight += weight
            score_sum += subject_avg / total_weight
            score_count += 1
        return score_sum / score_count
book = WeightedGradebook()
book.add_student('Albert Einstein')
book.report_grade('Albert Einstein', 'Math', 75, 0.05)
book.report_grade('Albert Einstein', 'Math', 65, 0.15)
book.report_grade('Albert Einstein', 'Math', 70, 0.80)
book.report_grade('Albert Einstein', 'Gym', 100, 0.40)
book.report_grade('Albert Einstein', 'Gym', 85, 0.60)
print(book.average_grade('Albert Einstein'))
>>>
80.25
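Avoid nesting built-in types more than one level deep (dictionaries inside dictionaries quickly become a maintenance nightmare).
Refactor into classes instead. A natural first step is to represent each grade as a (score, weight) tuple: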
grades = []
grades.append((95, 0.45))
grades.append((85, 0.55))
total = sum(score * weight for score, weight in grades)
total_weight = sum(weight for _, weight in grades)
average_grade = total / total_weight
If I also want to store notes from the teacher, every place that unpacks the tuples needs yet another _ placeholder:
grades = []
grades.append((95, 0.45, 'Great job'))
grades.append((85, 0.55, 'Better next time'))
total = sum(score * weight for score, weight, _ in grades)
total_weight = sum(weight for _, weight, _ in grades)
average_grade = total / total_weight
This is exactly the use case for namedtuple from the collections module, which defines a small, immutable data class:
from collections import namedtuple
Grade = namedtuple('Grade', ('score', 'weight'))
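A minimal sketch of how a namedtuple replaces the positional tuples above, assuming a hypothetical three-field variant named NotedGrade that also holds the teacher's note. Fields are read by name, so the _ placeholders disappear, but positional access still works because instances are tuples:

```python
from collections import namedtuple

# Hypothetical variant of Grade that also stores a teacher's note
NotedGrade = namedtuple('NotedGrade', ('score', 'weight', 'notes'))

grades = [
    NotedGrade(95, 0.45, 'Great job'),
    NotedGrade(85, 0.55, 'Better next time'),
]

# Fields are accessed by name, so no placeholder variables are needed
total = sum(g.score * g.weight for g in grades)
total_weight = sum(g.weight for g in grades)
average_grade = total / total_weight

# Instances are still plain tuples: indexing and unpacking also work
first_score = grades[0][0]
score, weight, notes = grades[1]
```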
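However, namedtuple has limitations:
Specifying default argument values is awkward (the defaults parameter only exists since Python 3.7, and it applies to the rightmost fields).
This makes namedtuple unwieldy when the data has many optional properties. With more than a handful of attributes, the built-in dataclasses module may be a better choice.
The attribute values of a namedtuple instance are still accessible through numerical indexes and iteration, which can lead to unintended usage that makes it harder to move to a real class later. If you can't control how the instances are used, it's better to define a new class explicitly. Here is a Subject class that holds a list of Grade values, followed by Student and Gradebook classes built on top of it: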
class Subject:
    def __init__(self):
        self._grades = []
    def report_grade(self, score, weight):
        self._grades.append(Grade(score, weight))
    def average_grade(self):
        total, total_weight = 0, 0
        for grade in self._grades:
            total += grade.score * grade.weight
            total_weight += grade.weight
        return total / total_weight
class Student:
    def __init__(self):
        self._subjects = defaultdict(Subject)
    def get_subject(self, name):
        return self._subjects[name]
    def average_grade(self):
        total, count = 0, 0
        for subject in self._subjects.values():
            total += subject.average_grade()
            count += 1
        return total / count
class Gradebook:
    def __init__(self):
        self._students = defaultdict(Student)
    def get_student(self, name):
        return self._students[name]
book = Gradebook()
albert = book.get_student('Albert Einstein')
math = albert.get_subject('Math')
math.report_grade(75, 0.05)
math.report_grade(65, 0.15)
math.report_grade(70, 0.80)
gym = albert.get_subject('Gym')
gym.report_grade(100, 0.40)
gym.report_grade(85, 0.60)
print(albert.average_grade())
>>>
80.25
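For grades with optional properties, the standard dataclasses module (Python 3.7+) is one alternative to namedtuple. The sketch below assumes a hypothetical GradeRecord class with an optional notes field. Unlike a namedtuple, its instances can't be indexed like tuples, so callers can only use attribute names:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class GradeRecord:
    score: float
    weight: float
    notes: str = ''  # optional property with a default value

def weighted_average(records):
    """Weighted average of a list of GradeRecord objects."""
    total = sum(r.score * r.weight for r in records)
    total_weight = sum(r.weight for r in records)
    return total / total_weight

records = [GradeRecord(100, 0.40), GradeRecord(85, 0.60, notes='Solid')]
gym_average = weighted_average(records)
```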
- Item38: Accept functions instead of classes for simple interfaces
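Many of Python's built-in APIs let you customize behavior by passing in a function. These hooks are called back by the API while it runs. For example, the list type's sort method takes an optional key argument, a function that determines each element's sort value: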
names = ['Socrates', 'Archimedes', 'Plato', 'Aristotle']
names.sort(key=len)
print(names)
>>>
['Plato', 'Socrates', 'Aristotle', 'Archimedes']
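Any callable that takes one element and returns a comparable value works as the key hook. The small sketch below uses made-up sample lists. A lambda sorts by length and then alphabetically, and the method str.lower gives a case-insensitive sort:

```python
names = ['Socrates', 'Archimedes', 'Plato', 'Aristotle']

# Sort by length first, then alphabetically to break ties
by_len_then_alpha = sorted(names, key=lambda name: (len(name), name))

# A method can serve as the hook too: case-insensitive ordering
mixed = ['plato', 'Socrates', 'aristotle']
case_insensitive = sorted(mixed, key=str.lower)
```

Without the key, uppercase letters sort before lowercase ones, so 'Socrates' would come first.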
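There are many other examples of hooks. defaultdict, for instance, accepts a function (or a class) that takes no arguments and is called to produce the default value whenever a key is missing.
Here is a hook that logs each time a key is missing and returns 0 as the default value: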
def log_missing():
    print('Key added')
    return 0
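Starting from the existing dictionary current, I build result and then apply a list of increments to it. Each missing key receives the default value 0 returned by log_missing: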
from collections import defaultdict
current = {'green': 12, 'blue': 3}
increments = [
('red', 5),
('blue', 17),
('orange', 9),
]
result = defaultdict(log_missing, current)
print('Before:', dict(result))
for key, amount in increments:
    result[key] += amount
print('After: ', dict(result))
>>>
Before: {'green': 12, 'blue': 3}
Key added
Key added
After: {'green': 12, 'blue': 20, 'red': 5, 'orange': 9}
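defaultdict only calls its hook from item lookups with the [] operator (through its __missing__ method). The get method and the in operator do not trigger it. A small sketch, with a hypothetical hook that records its calls:

```python
from collections import defaultdict

calls = []

def record_missing():
    # Called with no arguments whenever result[key] misses
    calls.append('missing')
    return 0

result = defaultdict(record_missing, {'green': 12})

missing_red = result.get('red')  # returns None; the hook is not called
found_blue = 'blue' in result    # membership test; the hook is not called
result['orange'] += 9            # missing key: hook supplies 0, then 9 is added
```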
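Now suppose I want the default-value hook to count how many keys were missing. One way is a stateful closure, which keeps the count in the enclosing function's scope: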
def increment_with_report(current, increments):
    added_count = 0
    def missing():
        nonlocal added_count  # Stateful closure
        added_count += 1
        return 0
    result = defaultdict(missing, current)
    for key, amount in increments:
        result[key] += amount
    return result, added_count
Even though defaultdict has no idea that the missing hook maintains state, running the function produces the expected count of 2:
result, count = increment_with_report(current, increments)
assert count == 2
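Another approach, common in other languages, is to define a small class that encapsulates the state and pass a bound method of an instance as the hook: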
class CountMissing:
    def __init__(self):
        self.added = 0
    def missing(self):
        self.added += 1
        return 0
This also achieves the goal:
counter = CountMissing()
result = defaultdict(counter.missing, current) # Method ref
for key, amount in increments:
    result[key] += amount
assert counter.added == 2
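Although the class is clearer than the closure, the purpose of CountMissing isn't obvious when read in isolation; it only becomes clear once you see how it's used with defaultdict. Who constructs a CountMissing object? Who calls the missing method? Will the class need other public methods in the future?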
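Python lets a class define the __call__ special method, which allows its instances to be called like functions. The built-in callable function returns True for instances of such a class: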
class BetterCountMissing:
    def __init__(self):
        self.added = 0
    def __call__(self):
        self.added += 1
        return 0
counter = BetterCountMissing()
assert counter() == 0
assert callable(counter)
When a key is missing, defaultdict calls the counter instance once, which runs its __call__ method:
counter = BetterCountMissing()
result = defaultdict(counter, current) # Relies on __call__
for key, amount in increments:
    result[key] += amount
assert counter.added == 2
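This meets the requirement cleanly. Defining __call__ also signals that instances of the class are meant to be used wherever a function argument would work, such as an API hook.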
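defaultdict only needs something it can call with no arguments, so built-in classes, lambdas, and instances with __call__ are interchangeable as hooks. The brief sketch below uses an invented word list and a hypothetical CallCounter class:

```python
from collections import defaultdict

words = ['apple', 'avocado', 'banana']

counts = defaultdict(int)            # int() returns 0
groups = defaultdict(list)           # list() returns []
labels = defaultdict(lambda: 'n/a')  # any zero-argument function works

for word in words:
    counts[word[0]] += 1
    groups[word[0]].append(word)

class CallCounter:
    """Callable object that counts how often defaultdict invokes it."""
    def __init__(self):
        self.calls = 0

    def __call__(self):
        self.calls += 1
        return 0

hook = CallCounter()
tracked = defaultdict(hook)
for word in words:
    tracked[word[0]] += 1
```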
- Item39: Use @classmethod polymorphism to construct objects generically
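In Python, not only do objects support polymorphism, but classes do as well. What is that good for?
Polymorphism lets multiple classes in a hierarchy implement their own unique versions of a method. This means that many classes can fulfill the same interface or abstract base class while providing different functionality.
For example, say I'm writing a MapReduce implementation, and I want a common class to represent the input data: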
class InputData:
    def read(self):
        raise NotImplementedError
Here is a concrete subclass that reads data from a file on disk:
class PathInputData(InputData):
    def __init__(self, path):
        super().__init__()
        self.path = path
    def read(self):
        with open(self.path) as f:
            return f.read()
I could have any number of InputData subclasses, such as NetworkInputData. Similarly, MapReduce workers need a standard interface for consuming the input data:
class Worker:
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None
    def map(self):
        raise NotImplementedError
    def reduce(self, other):
        raise NotImplementedError
Here is a concrete Worker subclass that counts lines (newline characters):
class LineCountWorker(Worker):
    def map(self):
        data = self.input_data.read()  # Read this worker's input
        self.result = data.count('\n')  # Number of lines in this chunk
    def reduce(self, other):
        self.result += other.result  # Merge another worker's result
Next I need a helper function that creates the InputData instances, one for each file in a directory:
import os
def generate_inputs(data_dir):
    for name in os.listdir(data_dir):
        yield PathInputData(os.path.join(data_dir, name))
Then I create the LineCountWorker instances that consume this data:
def create_workers(input_list):
    workers = []
    for input_data in input_list:
        workers.append(LineCountWorker(input_data))
    return workers
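I execute these workers by fanning out the map step to multiple threads, then calling reduce repeatedly to combine the results into a final value: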
from threading import Thread
def execute(workers):
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads: thread.start()
    for thread in threads: thread.join()
    first, *rest = workers
    for worker in rest:
        first.reduce(worker)
    return first.result
Finally, I connect the helpers together in a function that runs each step and returns the result:
def mapreduce(data_dir):
    inputs = generate_inputs(data_dir)
    workers = create_workers(inputs)
    return execute(workers)
Running this on a set of randomly generated test files shows that it works:
import os
import random
def write_test_files(tmpdir):
    os.makedirs(tmpdir)
    for i in range(100):
        with open(os.path.join(tmpdir, str(i)), 'w') as f:
            f.write('\n' * random.randint(0, 100))
tmpdir = 'test_inputs'
write_test_files(tmpdir)
result = mapreduce(tmpdir)
print(f'There are {result} lines')
>>>
There are 4360 lines
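What's the problem? The mapreduce function is not generic at all. If I want to write another InputData or Worker subclass, I also have to rewrite the generate_inputs, create_workers, and mapreduce functions to match.
The best approach is class method polymorphism: each class has only a single constructor, __init__, and it's unreasonable to require every InputData subclass to have a compatible one.
First, I use @classmethod to define a generic way to create new InputData instances: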
class GenericInputData:
    def read(self):
        raise NotImplementedError
    @classmethod
    def generate_inputs(cls, config):
        raise NotImplementedError
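The config parameter is a dictionary of configuration values that each concrete subclass interprets as it sees fit. PathInputData uses it to find the directory to list: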
class PathInputData(GenericInputData):
    ...  # __init__ and read() as before
    @classmethod
    def generate_inputs(cls, config):
        data_dir = config['data_dir']
        for name in os.listdir(data_dir):
            yield cls(os.path.join(data_dir, name))
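Similarly, I can make the create_workers helper part of a generic Worker base class. It uses cls(...) to construct instances of whichever concrete subclass the method is called on: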
class GenericWorker:
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None
    def map(self):
        raise NotImplementedError
    def reduce(self, other):
        raise NotImplementedError
    @classmethod
    def create_workers(cls, input_class, config):
        workers = []
        for input_data in input_class.generate_inputs(config):
            workers.append(cls(input_data))
        return workers
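Note that the call to input_class.generate_inputs is class polymorphism at work. Also note that create_workers calls cls(...), which provides an alternative way to construct GenericWorker objects besides calling __init__ directly. A concrete worker only needs to change its parent class: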
class LineCountWorker(GenericWorker):
    ...  # map() and reduce() as before
Finally, I rewrite the mapreduce function to be completely generic by having it call create_workers:
def mapreduce(worker_class, input_class, config):
    workers = worker_class.create_workers(input_class, config)
    return execute(workers)
config = {'data_dir': tmpdir}
result = mapreduce(LineCountWorker, PathInputData, config)
print(f'There are {result} lines')
>>>
There are 4360 lines
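As this shows, the cls parameter of a @classmethod connects generic code to concrete classes: new GenericInputData and GenericWorker subclasses plug in without rewriting the glue code.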
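The self-contained sketch below checks that the generic version really is pluggable. It reuses the same pattern with two hypothetical classes that are not in the original: StringInputData, which reads from an in-memory list given in config['texts'], and WordCountWorker. The base classes and execute are repeated so the block runs on its own:

```python
from threading import Thread

class GenericInputData:
    def read(self):
        raise NotImplementedError

    @classmethod
    def generate_inputs(cls, config):
        raise NotImplementedError

class StringInputData(GenericInputData):
    """Hypothetical input source backed by an in-memory string."""
    def __init__(self, text):
        super().__init__()
        self.text = text

    def read(self):
        return self.text

    @classmethod
    def generate_inputs(cls, config):
        # cls is StringInputData here (or whichever subclass it's called on)
        for text in config['texts']:
            yield cls(text)

class GenericWorker:
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None

    def map(self):
        raise NotImplementedError

    def reduce(self, other):
        raise NotImplementedError

    @classmethod
    def create_workers(cls, input_class, config):
        workers = []
        for input_data in input_class.generate_inputs(config):
            workers.append(cls(input_data))
        return workers

class WordCountWorker(GenericWorker):
    """Hypothetical worker that counts whitespace-separated words."""
    def map(self):
        self.result = len(self.input_data.read().split())

    def reduce(self, other):
        self.result += other.result

def execute(workers):
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    first, *rest = workers
    for worker in rest:
        first.reduce(worker)
    return first.result

def mapreduce(worker_class, input_class, config):
    workers = worker_class.create_workers(input_class, config)
    return execute(workers)

config = {'texts': ['to be or not', 'to be', 'that is the question']}
word_total = mapreduce(WordCountWorker, StringInputData, config)
```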
- Item40: Initialize parent classes with super
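The old, simple way to initialize a parent class from a child class is to call the parent's __init__ method directly with the child instance: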
class MyBaseClass:
    def __init__(self, value):
        self.value = value
class MyChildClass(MyBaseClass):
    def __init__(self):
        MyBaseClass.__init__(self, 5)
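This works for simple hierarchies but breaks down in many cases, especially with multiple inheritance. For example, here are two classes that operate on the instance's value field: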
class TimesTwo:
    def __init__(self):
        self.value *= 2
class PlusFive:
    def __init__(self):
        self.value += 5
This class lists its parent classes in one order, and its __init__ calls their constructors in the same order:
class OneWay(MyBaseClass, TimesTwo, PlusFive):
    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        TimesTwo.__init__(self)
        PlusFive.__init__(self)
The result is:
foo = OneWay(5)
print('First ordering value is (5 * 2) + 5 =', foo.value)
>>>
First ordering value is (5 * 2) + 5 = 15
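Here is another class that lists the same parent classes in a different order (PlusFive, then TimesTwo) but still calls their constructors in the same order as before: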
class AnotherWay(MyBaseClass, PlusFive, TimesTwo):
    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        TimesTwo.__init__(self)
        PlusFive.__init__(self)
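The order in which the parent classes are listed no longer matches the order of the __init__ calls. A reader might expect (5 + 5) * 2 = 20, but the result is still 15. This kind of mismatch is hard to spot, especially for newcomers: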
bar = AnotherWay(5)
print('Second ordering value is', bar.value)
>>>
Second ordering value is 15
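Another problem occurs with diamond inheritance, where a subclass inherits from two separate classes that share the same superclass. For example, here are two classes that both inherit from MyBaseClass: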
class TimesSeven(MyBaseClass):
    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        self.value *= 7
class PlusNine(MyBaseClass):
    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        self.value += 9
Then I define a child class that inherits from both, making MyBaseClass the top of the diamond:
class ThisWay(TimesSeven, PlusNine):
    def __init__(self, value):
        TimesSeven.__init__(self, value)
        PlusNine.__init__(self, value)
foo = ThisWay(5)
print('Should be (5 * 7) + 9 = 44 but is', foo.value)
>>>
Should be (5 * 7) + 9 = 44 but is 14
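The second parent's constructor, PlusNine.__init__, calls MyBaseClass.__init__ again, which resets self.value to 5. The result is therefore 5 + 9 = 14, and the multiplication by 7 is silently lost. In more complex hierarchies, this is very hard to track down.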
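The double initialization can be observed directly. This minimal sketch mirrors the diamond above and adds a hypothetical init_calls counter to the base class:

```python
class CountingBase:
    init_calls = 0  # class-level counter of __init__ invocations

    def __init__(self, value):
        CountingBase.init_calls += 1
        self.value = value

class TimesSevenDirect(CountingBase):
    def __init__(self, value):
        CountingBase.__init__(self, value)
        self.value *= 7

class PlusNineDirect(CountingBase):
    def __init__(self, value):
        CountingBase.__init__(self, value)
        self.value += 9

class DiamondDirect(TimesSevenDirect, PlusNineDirect):
    def __init__(self, value):
        TimesSevenDirect.__init__(self, value)  # value becomes 5 * 7 = 35
        PlusNineDirect.__init__(self, value)    # resets value to 5, then adds 9

obj = DiamondDirect(5)
```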
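To solve these problems, Python has the super built-in function and a standard method resolution order (MRO). super ensures that common superclasses in diamond hierarchies are run only once. The MRO defines the order in which superclasses are initialized, following an algorithm called C3 linearization. Here I recreate the diamond hierarchy using super: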
class TimesSevenCorrect(MyBaseClass):
    def __init__(self, value):
        super().__init__(value)
        self.value *= 7
class PlusNineCorrect(MyBaseClass):
    def __init__(self, value):
        super().__init__(value)
        self.value += 9
Now the top of the diamond, MyBaseClass.__init__, runs only once, and the result is correct:
class GoodWay(TimesSevenCorrect, PlusNineCorrect):
    def __init__(self, value):
        super().__init__(value)
foo = GoodWay(5)
print('Should be 7 * (5 + 9) = 98 and is', foo.value)
>>>
Should be 7 * (5 + 9) = 98 and is 98
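The order may seem backwards: 7 * (5 + 9) rather than (5 * 7) + 9. That's because the calls follow the MRO. GoodWay's super() call goes to TimesSevenCorrect.__init__, whose own super() call goes to PlusNineCorrect.__init__, which in turn calls MyBaseClass.__init__. The assignments then run as the calls return: value is set to 5, PlusNineCorrect adds 9, and TimesSevenCorrect multiplies by 7. You can inspect the MRO with the mro class method: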
mro_str = '\n'.join(repr(cls) for cls in GoodWay.mro())
print(mro_str)
>>>
<class '__main__.GoodWay'>
<class '__main__.TimesSevenCorrect'>
<class '__main__.PlusNineCorrect'>
<class '__main__.MyBaseClass'>
<class 'object'>
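To see the MRO at work, this sketch records when each __init__ starts and finishes. The trace list and the Traced* class names are hypothetical stand-ins for the classes above. The calls travel down the MRO, and the arithmetic happens as they return:

```python
trace = []

class TracedBase:
    def __init__(self, value):
        trace.append('Base')
        self.value = value

class TracedTimesSeven(TracedBase):
    def __init__(self, value):
        trace.append('TimesSeven enter')
        super().__init__(value)  # next class in the instance's MRO
        self.value *= 7
        trace.append('TimesSeven exit')

class TracedPlusNine(TracedBase):
    def __init__(self, value):
        trace.append('PlusNine enter')
        super().__init__(value)
        self.value += 9
        trace.append('PlusNine exit')

class TracedGoodWay(TracedTimesSeven, TracedPlusNine):
    pass

good = TracedGoodWay(5)
good_trace = list(trace)
mro_names = [cls.__name__ for cls in TracedGoodWay.__mro__]

# Outside the diamond, super() in TracedPlusNine goes straight to TracedBase
alone = TracedPlusNine(5)
```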
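super can also be called with two parameters: first, the type of the class whose MRO parent view you want to access; second, the instance on which to access that view: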
class ExplicitTrisect(MyBaseClass):
    def __init__(self, value):
        super(ExplicitTrisect, self).__init__(value)
        self.value /= 3
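These parameters are not required for object instance initialization. When super is called with zero arguments inside a class definition, the compiler automatically supplies the correct ones (__class__ and self). So the following forms are all equivalent: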
class AutomaticTrisect(MyBaseClass):
    def __init__(self, value):
        super(__class__, self).__init__(value)
        self.value /= 3
class ImplicitTrisect(MyBaseClass):
    def __init__(self, value):
        super().__init__(value)
        self.value /= 3
assert ExplicitTrisect(9).value == 3
assert AutomaticTrisect(9).value == 3
assert ImplicitTrisect(9).value == 3
- Item41: Consider composing functionality with mix-in classes
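It's best to avoid multiple inheritance where possible. Consider writing a mix-in instead: a class that defines only a small set of additional methods for its child classes to use, without defining its own instance attributes.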
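For example, suppose I need to convert Python objects from their in-memory representation into dictionaries ready for serialization. A mix-in can provide this generically for all my classes: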
class ToDictMixin:
    def to_dict(self):
        return self._traverse_dict(self.__dict__)
The implementation relies on dynamic attribute access with hasattr, dynamic type inspection with isinstance, and the instance dictionary __dict__:
    def _traverse_dict(self, instance_dict):
        output = {}
        for key, value in instance_dict.items():
            output[key] = self._traverse(key, value)
        return output
    def _traverse(self, key, value):
        if isinstance(value, ToDictMixin):
            return value.to_dict()
        elif isinstance(value, dict):
            return self._traverse_dict(value)
        elif isinstance(value, list):
            return [self._traverse(key, i) for i in value]
        elif hasattr(value, '__dict__'):
            return self._traverse_dict(value.__dict__)
        else:
            return value
Here I define an example class that uses the mix-in to produce a dictionary representation of a binary tree:
class BinaryTree(ToDictMixin):
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right
Translating a large number of related Python objects into a dictionary becomes easy:
tree = BinaryTree(10,
                  left=BinaryTree(7, right=BinaryTree(9)),
                  right=BinaryTree(13, left=BinaryTree(11)))
print(tree.to_dict())
>>>
{'value': 10,
 'left': {'value': 7,
          'left': None,
          'right': {'value': 9, 'left': None, 'right': None}},
 'right': {'value': 13,
           'left': {'value': 11, 'left': None, 'right': None},
           'right': None}}
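Mix-ins can also be made pluggable so that behaviors are overridden when required. Here is a subclass of BinaryTree that holds a reference to its parent. This circular reference would cause the default implementation of ToDictMixin.to_dict to loop forever: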
class BinaryTreeWithParent(BinaryTree):
    def __init__(self, value, left=None,
                 right=None, parent=None):
        super().__init__(value, left=left, right=right)
        self.parent = parent
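The solution is to override the _traverse method in BinaryTreeWithParent so that it only processes the values that matter, which prevents the cycles the mix-in would otherwise follow. For the parent attribute, the override inserts the parent's numerical value; everything else goes to the mix-in's default implementation via super: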
    def _traverse(self, key, value):
        if (isinstance(value, BinaryTreeWithParent) and
                key == 'parent'):
            return value.value  # Prevent cycles
        else:
            return super()._traverse(key, value)
Calling BinaryTreeWithParent.to_dict now works without issue because the circular parent reference is no longer followed:
root = BinaryTreeWithParent(10)
root.left = BinaryTreeWithParent(7, parent=root)
root.left.right = BinaryTreeWithParent(9, parent=root.left)
print(root.to_dict())
>>>
{'value': 10,
 'left': {'value': 7,
          'left': None,
          'right': {'value': 9,
                    'left': None,
                    'right': None,
                    'parent': 7},
          'parent': 10},
 'right': None,
 'parent': None}
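By defining BinaryTreeWithParent._traverse, I've also enabled any class that has an attribute of type BinaryTreeWithParent to work automatically with ToDictMixin: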
class NamedSubTree(ToDictMixin):
    def __init__(self, name, tree_with_parent):
        self.name = name
        self.tree_with_parent = tree_with_parent
my_tree = NamedSubTree('foobar', root.left.right)
print(my_tree.to_dict()) # No infinite loop
>>>
{'name': 'foobar',
 'tree_with_parent': {'value': 9,
                      'left': None,
                      'right': None,
                      'parent': 7}}
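Mix-ins can be composed together. For example, say I need a mix-in that provides generic JSON serialization for any class, assuming the class provides a to_dict method (which may or may not come from ToDictMixin):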
import json
class JsonMixin:
    @classmethod
    def from_json(cls, data):
        kwargs = json.loads(data)
        return cls(**kwargs)
    def to_json(self):
        return json.dumps(self.to_dict())
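Note that JsonMixin defines both an instance method and a class method. The only requirements are that the class has a to_dict method and that its __init__ method accepts keyword arguments. With this mix-in, I can build hierarchies of classes that serialize to and from JSON, such as these classes describing part of a data center topology: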
class DatacenterRack(ToDictMixin, JsonMixin):
    def __init__(self, switch=None, machines=None):
        self.switch = Switch(**switch)
        self.machines = [
            Machine(**kwargs) for kwargs in machines]
class Switch(ToDictMixin, JsonMixin):
    def __init__(self, ports=None, speed=None):
        self.ports = ports
        self.speed = speed
class Machine(ToDictMixin, JsonMixin):
    def __init__(self, cores=None, ram=None, disk=None):
        self.cores = cores
        self.ram = ram
        self.disk = disk
Here I test the full round trip: loading the objects from JSON and serializing them back to JSON:
serialized = """{
"switch": {"ports": 5, "speed": 1e9},
"machines": [
{"cores": 8, "ram": 32e9, "disk": 5e12},
{"cores": 4, "ram": 16e9, "disk": 1e12},
{"cores": 2, "ram": 4e9, "disk": 500e9}
]
}"""
deserialized = DatacenterRack.from_json(serialized)
roundtrip = deserialized.to_json()
assert json.loads(serialized) == json.loads(roundtrip)
As this shows, composing small, pluggable mix-in classes provides a great deal of flexibility.
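The compact sketch below shows the dependency between the two mix-ins: JsonMixin assumes some other part of the class supplies to_dict. SimpleToDictMixin, Point, and JsonOnly are hypothetical simplifications, not the ToDictMixin above:

```python
import json

class SimpleToDictMixin:
    """Simplified stand-in: copies the instance dictionary (no nesting)."""
    def to_dict(self):
        return dict(self.__dict__)

class JsonMixin:
    @classmethod
    def from_json(cls, data):
        kwargs = json.loads(data)
        return cls(**kwargs)  # requires __init__ to accept keyword arguments

    def to_json(self):
        return json.dumps(self.to_dict())  # requires a to_dict method

class Point(SimpleToDictMixin, JsonMixin):
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y

class JsonOnly(JsonMixin):
    """Composes JsonMixin without any to_dict provider."""
    def __init__(self, x=0):
        self.x = x

point = Point.from_json('{"x": 1, "y": 2}')
roundtrip = point.to_json()

try:
    JsonOnly(1).to_json()
    json_only_works = True
except AttributeError:
    json_only_works = False  # there is no to_dict method to call
```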
- Item42: Prefer public attributes over private ones
In Python, a class's attributes have only two types of visibility: public and private:
class MyObject:
    def __init__(self):
        self.public_field = 5
        self.__private_field = 10
    def get_private_field(self):
        return self.__private_field
Public attributes can be accessed directly by anyone using the dot operator:
foo = MyObject()
assert foo.public_field == 5
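Private fields are specified by prefixing the attribute name with a double underscore. Methods of the containing class, such as get_private_field, can access them directly: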
assert foo.get_private_field() == 10
However, accessing a private field directly from outside the class raises an exception:
foo.__private_field
>>>
Traceback ...
AttributeError: 'MyObject' object has no attribute '__private_field'
Class methods also have access to private attributes because they are declared within the class block:
class MyOtherObject:
    def __init__(self):
        self.__private_field = 71
    @classmethod
    def get_private_field_of_instance(cls, instance):
        return instance.__private_field
bar = MyOtherObject()
assert MyOtherObject.get_private_field_of_instance(bar) == 71
A subclass can't access its parent class's private fields:
class MyParentObject:
    def __init__(self):
        self.__private_field = 71
class MyChildObject(MyParentObject):
    def get_private_field(self):
        return self.__private_field
baz = MyChildObject()
baz.get_private_field()
>>>
Traceback ...
AttributeError: 'MyChildObject' object has no attribute '_MyChildObject__private_field'
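Private attributes are implemented with a simple transformation of the attribute name (name mangling). Inside MyChildObject.get_private_field, __private_field is translated into _MyChildObject__private_field. The parent's __private_field, defined in MyParentObject.__init__, was translated into _MyParentObject__private_field. Knowing this rule, you can access the parent's private attribute directly: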
assert baz._MyParentObject__private_field == 71
Listing the object's instance dictionary through __dict__ also shows the mangled name:
print(baz.__dict__)
>>>
{'_MyParentObject__private_field': 71}
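Name mangling is a compile-time rewrite of identifiers written inside a class body that start with two underscores and don't end with two underscores. A small sketch with a hypothetical Demo class shows which names get rewritten:

```python
class Demo:
    def __init__(self):
        self.__hidden = 1     # stored as _Demo__hidden
        self._protected = 2   # single underscore: not mangled
        self.__dunder__ = 3   # trailing double underscore: not mangled

demo = Demo()
attribute_names = set(vars(demo))

# Strings passed to getattr are not rewritten, so the short name fails
try:
    getattr(demo, '__hidden')
    short_name_works = True
except AttributeError:
    short_name_works = False
```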
So Python does not strictly enforce privacy. For the sake of flexibility, users can always bypass it.
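Following the PEP 8 style guide (see Item 2), a field prefixed with a single underscore, like _protected_field, is protected by convention, meaning external users of the class should proceed with caution. Many programmers instead use private fields to mark an internal API that they don't want external code or subclasses to use: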
class MyStringClass:
    def __init__(self, value):
        self.__value = value
    def get_value(self):
        return str(self.__value)
foo = MyStringClass(5)
assert foo.get_value() == '5'
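This is the wrong approach. A subclass that needs the value can still reach the private field through its mangled name: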
class MyIntegerSubclass(MyStringClass):
    def get_value(self):
        return int(self._MyStringClass__value)
foo = MyIntegerSubclass('5')
assert foo.get_value() == 5
But such references break as soon as the class hierarchy changes. Suppose the private attribute moves into a new parent class, MyBaseClass:
class MyBaseClass:
    def __init__(self, value):
        self.__value = value
    def get_value(self):
        return self.__value
class MyStringClass(MyBaseClass):
    def get_value(self):
        return str(super().get_value())  # Updated
class MyIntegerSubclass(MyStringClass):
    def get_value(self):
        return int(self._MyStringClass__value)  # Not updated
foo = MyIntegerSubclass(5)
foo.get_value()
>>>
Traceback ...
AttributeError: 'MyIntegerSubclass' object has no attribute '_MyStringClass__value'
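The mangled name _MyStringClass__value no longer exists, because the field is now stored as _MyBaseClass__value. It's better to use a protected attribute and document it, telling other programmers that the field is internal and how it should be treated: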
class MyStringClass:
    def __init__(self, value):
        # This stores the user-supplied value for the object.
        # It should be coercible to a string. Once assigned in
        # the object it should be treated as immutable.
        self._value = value
    ...
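The one case where private attributes deserve serious consideration is to avoid naming conflicts with subclasses. A conflict occurs when a child class unwittingly defines an attribute that its parent class already uses: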
class ApiClass:
    def __init__(self):
        self._value = 5
    def get(self):
        return self._value
class Child(ApiClass):
    def __init__(self):
        super().__init__()
        self._value = 'hello'  # Conflicts
a = Child()
print(f'{a.get()} and {a._value} should be different')
>>>
hello and hello should be different
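To reduce the risk of a child class accidentally overwriting the attribute, use a private attribute in the parent class. Name mangling keeps it distinct from the child's attribute names: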
class ApiClass:
    def __init__(self):
        self.__value = 5  # Double underscore
    def get(self):
        return self.__value  # Double underscore
class Child(ApiClass):
    def __init__(self):
        super().__init__()
        self._value = 'hello'  # OK!
a = Child()
print(f'{a.get()} and {a._value} are different')
>>>
5 and hello are different
- Item43: Inherit from collections.abc for custom container types
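Every Python class is a container of some kind, encapsulating attributes and functionality together. Python also provides built-in container types such as list, tuple, set, and dict. For example, suppose I want a list that can also count the frequency of its members: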
class FrequencyList(list):
    def __init__(self, members):
        super().__init__(members)
    def frequency(self):
        counts = {}
        for item in self:
            counts[item] = counts.get(item, 0) + 1
        return counts
By subclassing list, I get all of list's standard functionality. I can then define additional methods to provide custom behavior:
foo = FrequencyList(['a', 'b', 'a', 'c', 'b', 'a', 'd'])
print('Length is', len(foo))
foo.pop()
print('After pop:', repr(foo))
print('Frequency:', foo.frequency())
>>>
Length is 7
After pop: ['a', 'b', 'a', 'c', 'b', 'a']
Frequency: {'a': 3, 'b': 2, 'c': 1}
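Now suppose I want an object that can be indexed like a list but isn't a list subclass. For example, I want to give sequence semantics to a binary tree node class: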
class BinaryNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right
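How do you make this class act like a sequence? Python implements container behaviors with special methods. An index access on a sequence: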
bar = [1, 2, 3]
bar[0]
# is interpreted as:
bar.__getitem__(0)
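So I can provide a custom __getitem__ implementation. It traverses the tree in order (left subtree, node, right subtree) and counts positions until it reaches the requested index: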
class IndexableNode(BinaryNode):
    def _traverse(self):
        if self.left is not None:
            yield from self.left._traverse()
        yield self
        if self.right is not None:
            yield from self.right._traverse()
    def __getitem__(self, index):
        for i, item in enumerate(self._traverse()):
            if i == index:
                return item.value
        raise IndexError(f'Index {index} is out of range')
I can construct the binary tree as usual:
tree = IndexableNode(
    10,
    left=IndexableNode(
        5,
        left=IndexableNode(2),
        right=IndexableNode(
            6,
            right=IndexableNode(7))),
    right=IndexableNode(
        15,
        left=IndexableNode(11)))
Besides traversing it through the left and right attributes, I can now access it like a list:
print('LRR is', tree.left.right.right.value)
print('Index 0 is', tree[0])
print('Index 1 is', tree[1])
print('11 in the tree?', 11 in tree)
print('17 in the tree?', 17 in tree)
print('Tree is', list(tree))
>>>
LRR is 7
Index 0 is 2
Index 1 is 5
11 in the tree? True
17 in the tree? False
Tree is [2, 5, 6, 7, 10, 11, 15]
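The in operator and list() work here even though IndexableNode defines neither __iter__ nor __contains__. When those methods are missing, Python falls back to calling __getitem__ with 0, 1, 2, ... until IndexError is raised. A minimal sketch with a hypothetical Squares class:

```python
class Squares:
    """Defines only __getitem__; iteration falls back to it."""
    def __init__(self, count):
        self.count = count

    def __getitem__(self, index):
        if index >= self.count:
            raise IndexError(f'Index {index} is out of range')
        return index * index

squares = Squares(5)
as_list = list(squares)   # calls __getitem__(0), (1), ... until IndexError
has_nine = 9 in squares   # membership also scans via __getitem__
```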
The problem is that implementing __getitem__ isn't enough to provide all the sequence semantics you'd expect from a list:
len(tree)
>>>
Traceback ...
TypeError: object of type 'IndexableNode' has no len()
The len built-in function requires another special method, __len__:
class SequenceNode(IndexableNode):
    def __len__(self):
        for count, _ in enumerate(self._traverse(), 1):
            pass
        return count
tree = SequenceNode(
    10,
    left=SequenceNode(
        5,
        left=SequenceNode(2),
        right=SequenceNode(
            6,
            right=SequenceNode(7))),
    right=SequenceNode(
        15,
        left=SequenceNode(11))
)
print('Tree length is', len(tree))
>>>
Tree length is 7
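Unfortunately, the count and index methods that you'd expect on a sequence are still missing, which shows how hard it is to define your own container types correctly. To avoid this difficulty, the collections.abc module provides a set of abstract base classes for the typical container types. If you subclass one of them and forget to implement the required methods, the module tells you something is wrong: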
from collections.abc import Sequence
class BadType(Sequence):
    pass
foo = BadType()
>>>
Traceback ...
TypeError: Can't instantiate abstract class BadType with abstract methods __getitem__, __len__
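Once a class implements all the methods required by the abstract base class, as SequenceNode does, inheriting from Sequence provides the additional methods, such as index and count, for free: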
class BetterNode(SequenceNode, Sequence):
    pass
tree = BetterNode(
    10,
    left=BetterNode(
        5,
        left=BetterNode(2),
        right=BetterNode(
            6,
            right=BetterNode(7))),
    right=BetterNode(
        15,
        left=BetterNode(11))
)
print('Index of 7 is', tree.index(7))
print('Count of 10 is', tree.count(10))
>>>
Index of 7 is 3
Count of 10 is 1
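The same mechanism applies to any class. In this short sketch, a hypothetical Countdown sequence implements only __getitem__ and __len__, and Sequence supplies __contains__, __iter__, __reversed__, index, and count:

```python
from collections.abc import Sequence

class Countdown(Sequence):
    """Hypothetical read-only sequence: start, start - 1, ..., 1."""
    def __init__(self, start):
        self._start = start

    def __getitem__(self, index):
        # Integer indexes only (no negative indexes or slices in this sketch)
        if not 0 <= index < self._start:
            raise IndexError(f'Index {index} is out of range')
        return self._start - index

    def __len__(self):
        return self._start

countdown = Countdown(3)
```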
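There are more abstract base classes, such as Set and MutableMapping, that you can implement to match the behavior of Python's built-in container types. The same applies to sorting (see Item73).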
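Here is a minimal sketch of the MutableMapping case, using a hypothetical case-insensitive mapping. The class implements the five abstract methods, and MutableMapping supplies methods such as get, update, setdefault, pop, and __contains__:

```python
from collections.abc import MutableMapping

class CaseInsensitiveDict(MutableMapping):
    """Hypothetical mapping that treats string keys case-insensitively."""
    def __init__(self, *args, **kwargs):
        self._data = {}
        self.update(*args, **kwargs)  # update() is provided by MutableMapping

    def __getitem__(self, key):
        return self._data[key.lower()]

    def __setitem__(self, key, value):
        self._data[key.lower()] = value

    def __delitem__(self, key):
        del self._data[key.lower()]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

headers = CaseInsensitiveDict({'Content-Type': 'text/plain'})
headers.setdefault('ACCEPT', '*/*')
```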