组合比继承更灵活,因为它可以建模松散耦合的关系。对组件类的更改对复合类影响很小或没有影响。基于组成的设计更适合更改
在本部分中,您将使用合成来实现仍然符合PayrollSystem
和ProductivitySystem
要求的更好的设计
# In productivity.py
class ProductivitySystem:
def __init__(self):
self._roles = {
'manager': ManagerRole,
'secretary': SecretaryRole,
'sales': SalesRole,
'factory': FactoryRole,
}
def get_role(self, role_id):
role_type = self._roles.get(role_id)
if not role_type:
raise ValueError('role_id')
return role_type()
def track(self, employees, hours):
print('Tracking Employee Productivity')
print('==============================')
for employee in employees:
employee.work(hours)
print('')
ProductivitySystem
类使用映射到实现角色的角色类的字符串标识符来定义一些角色。它公开一个.get_role()
方法,该方法在给定角色标识符的情况下,返回角色类型对象。如果没有找到该角色,则会引发ValueError
异常
# In productivity.py
class ManagerRole:
def perform_duties(self, hours):
return f'screams and yells for {hours} hours.'
class SecretaryRole:
def perform_duties(self, hours):
return f'does paperwork for {hours} hours.'
class SalesRole:
def perform_duties(self, hours):
return f'expends {hours} hours on the phone.'
class FactoryRole:
def perform_duties(self, hours):
return f'manufactures gadgets for {hours} hours.'
您实现的每个角色都公开了一个.perform_duties()
,它占用了工作的小时数。这些方法返回一个表示职责的字符串
角色类彼此独立,但它们公开相同的接口,因此它们是可互换的。稍后您将看到如何在应用程序中使用它们
# In hr.py
class PayrollSystem:
def __init__(self):
self._employee_policies = {
1: SalaryPolicy(3000),
2: SalaryPolicy(1500),
3: CommissionPolicy(1000, 100),
4: HourlyPolicy(15),
5: HourlyPolicy(9)
}
def get_policy(self, employee_id):
policy = self._employee_policies.get(employee_id)
if not policy:
return ValueError(employee_id)
return policy
def calculate_payroll(self, employees):
print('Calculating Payroll')
print('===================')
for employee in employees:
print(f'Payroll for: {employee.id} - {employee.name}')
print(f'- Check amount: {employee.calculate_payroll()}')
if employee.address:
print('- Sent to:')
print(employee.address)
print('')
PayrollSystem
为每个员工保留一个工资政策的内部数据库。它公开一个.get_policy()
,给定一个员工id,返回其工资单策略。如果系统中不存在指定的id,则该方法将引发ValueError
异常
calculate_payroll()
的实现与以前的工作方式相同。它获取一个雇员列表,计算工资单,并打印结果
# In hr.py
class PayrollPolicy:
def __init__(self):
self.hours_worked = 0
def track_work(self, hours):
self.hours_worked += hours
class SalaryPolicy(PayrollPolicy):
def __init__(self, weekly_salary):
super().__init__()
self.weekly_salary = weekly_salary
def calculate_payroll(self):
return self.weekly_salary
class HourlyPolicy(PayrollPolicy):
def __init__(self, hour_rate):
super().__init__()
self.hour_rate = hour_rate
def calculate_payroll(self):
return self.hours_worked * self.hour_rate
class CommissionPolicy(SalaryPolicy):
def __init__(self, weekly_salary, commission_per_sale):
super().__init__(weekly_salary)
self.commission_per_sale = commission_per_sale
@property
def commission(self):
sales = self.hours_worked / 5
return sales * self.commission_per_sale
def calculate_payroll(self):
fixed = super().calculate_payroll()
return fixed + self.commission
首先,您要实现一个PayrollPolicy类,该类充当所有薪资策略的基类。此类跟踪工作小时数,这是所有工资单政策所共有的
其他策略类源自PayrollPolicy
。我们在这里使用继承是因为我们想利用PayrollPolicy
的实现。此外,SalaryPolicy
,HourlyPolicy
和CommissionPolicy
也是PayrollPolicy
SalaryPolicy
使用weekly_salary
值初始化,然后在.calculate_payroll()
中使用该值。HourlyPolicy
使用hour_rate
初始化,并通过利用基本类hours_working
实现.calculate_payroll()
CommissionPolicy
类派生自SalaryPolicy
,因为它希望继承其实现。它是用weekly_salary
参数初始化的,但是它还需要一个common_per_sale
参数
使用common_per_sale
来计算.commission
,它被实现为一个属性,因此在请求时计算它。在这个例子中,我们假设每5小时工作一次,而.commission
是销售的数量乘以commission_per_sale
值
首先利用SalaryPolicy
中的实现,然后添加计算佣金,从而实现.calculate_payroll()
方法。
# In contacts.py
class AddressBook:
def __init__(self):
self._employee_addresses = {
1: Address('121 Admin Rd.', 'Concord', 'NH', '03301'),
2: Address('67 Paperwork Ave', 'Manchester', 'NH', '03101'),
3: Address('15 Rose St', 'Concord', 'NH', '03301', 'Apt. B-1'),
4: Address('39 Sole St.', 'Concord', 'NH', '03301'),
5: Address('99 Mountain Rd.', 'Concord', 'NH', '03301'),
}
def get_employee_address(self, employee_id):
address = self._employee_addresses.get(employee_id)
if not address:
raise ValueError(employee_id)
return address
AddressBook
类为每个员工保留一个Address
对象的内部数据库。它公开一个get_employee_address()
方法,该方法返回指定员工id的地址。如果员工id不存在,则会引发一个ValueError
错误
# In contacts.py
class Address:
def __init__(self, street, city, state, zipcode, street2=''):
self.street = street
self.street2 = street2
self.city = city
self.state = state
self.zipcode = zipcode
def __str__(self):
lines = [self.street]
if self.street2:
lines.append(self.street2)
lines.append(f'{self.city}, {self.state} {self.zipcode}')
return '\n'.join(lines)
该类管理地址组件并提供地址的漂亮表示形式
到目前为止,已经扩展了新类以支持更多功能,但是对以前的设计没有重大更改。这将随着员工模块及其类的设计而改变