Python OOP
Basic
概念
面向对象编程(OOP)是一种基于"对象"概念的编程范式,它将数据(属性)和操作数据的方法(行为)捆绑在一起。
对象与类的关系
- 类(Class): 是创建对象的蓝图或模板,定义了对象的属性和方法
- 对象(Object): 是类的实例,具有类定义的属性和方法
class Vector2D:
""" 定义类:二维向量——最小可运行示例 """
__slots__ = ('x', 'y') # 节省内存,防止随意加属性
def __init__(self, x=0, y=0):
self.x, self.y = float(x), float(y)
def __repr__(self):
return f'Vector2D({self.x}, {self.y})'
# 重载运算符:让向量支持 + - * /
def __add__(self, other):
return Vector2D(self.x + other.x, self.y + other.y)
def __mul__(self, scalar):
return Vector2D(self.x * scalar, self.y * scalar)
# 自定义协议:len()、bool()
def __len__(self):
return int((self.x ** 2 + self.y ** 2) ** 0.5)
def __bool__(self):
return bool(len(self))
# 使用类创建对象
v1 = Vector2D(3, 4)
v2 = Vector2D(1, 2)
print(v1 + v2 * 2) # Vector2D(5.0, 8.0)
封装
数据隐藏: 限制对对象内部状态的直接访问
接口暴露: 提供公共方法来操作对象
公共成员(Public):无下划线,可在任何地方访问
保护成员(Protected):单下划线_开头,约定为"内部使用"(非强制)
私有成员(Private):双下划线__开头,会触发名称修饰(name mangling)
class Account:
def __init__(self, owner, balance=0):
self.owner = owner
self.__balance = balance # 私有属性
@property
def balance(self):
return self.__balance
def deposit(self, amount):
if amount <= 0:
raise ValueError('amount > 0')
self.__balance += amount
def withdraw(self, amount):
if amount > self.__balance:
raise ValueError('insufficient funds')
self.__balance -= amount
继承
继承允许新类(子类)获取现有类(父类)的属性和方法,并可以扩展或修改它们。
继承类型:
- 单继承
- 多重继承
- 多级继承
class Animal:
def __init__(self, name):
self.name = name
def eat(self):
print(f"{self.name} is eating")
def sleep(self):
print(f"{self.name} is sleeping")
class Pet(Animal):
def __init__(self, name, owner):
super().__init__(name)
self.owner = owner
def play(self):
print(f"{self.name} is playing with {self.owner}")
class Dog(Pet):
def __init__(self, name, owner, breed):
super().__init__(name, owner)
self.breed = breed
def bark(self):
print(f"{self.name} says woof!")
# 方法重写
def play(self):
print(f"{self.name} the {self.breed} is fetching the ball for {self.owner}")
# 使用继承
buddy = Dog("Buddy", "Alice", "Golden Retriever")
buddy.eat() # 继承自Animal
buddy.sleep() # 继承自Animal
buddy.play() # Dog重写的方法
buddy.bark() # Dog新增的方法
多重继承 & MRO(方法解析顺序)
class A:
value = 'A'
class B(A):
value = 'B'
class C(A):
value = 'C'
class D(B, C):
pass
print(D.value) # B
print(D.__mro__) # (<class '__main__.D'>, <class '__main__.B'>, <class '__main__.C'>, <class '__main__.A'>, <class 'object'>)
多态
多态允许不同类的对象对相同消息(方法调用)做出不同响应。
- 方法重写(Override)
- 鸭子类型(Duck Typing)
from abc import ABC, abstractmethod
class Animal(ABC):
@abstractmethod
def speak(self): ...
class Dog(Animal):
def speak(self): return 'woof!'
class Cat(Animal):
def speak(self): return 'meow!'
def animal_talk(animal: Animal):
print(animal.speak())
for pet in (Dog(), Cat()):
animal_talk(pet) # 输出 woof! meow!
抽象类
抽象类是不能实例化的类,用于定义子类必须实现的接口
from abc import ABC, abstractmethod
class Shape(ABC):
@abstractmethod
def area(self):
pass
@abstractmethod
def perimeter(self):
pass
class Rectangle(Shape):
def __init__(self, width, height):
self.width = width
self.height = height
def area(self):
return self.width * self.height
def perimeter(self):
return 2 * (self.width + self.height)
接口
Python中没有专门的接口语法,通常使用抽象类或协议(Protocol)实现。
from typing import Protocol
class Drawable(Protocol):
def draw(self) -> None:
...
class Circle:
def draw(self):
print("Drawing Circle")
def render(drawable: Drawable):
drawable.draw()
接口和抽象类的区别
特性 | 接口(Interface) | 抽象类(Abstract Class) |
---|---|---|
定义方式 | 通过普通类或ABC模块实现 | 必须继承abc.ABC 或使用metaclass=ABCMeta |
方法实现 | 所有方法都是未实现的 | 可以包含具体实现方法和抽象方法 |
多继承 | 更适合多继承场景 | 多继承时可能产生菱形问题 |
设计目的 | 定义行为契约 | 提供部分实现,要求子类完成剩余部分 |
实例化 | 不能实例化 | 不能实例化 |
Python版本支持 | 无原生接口,通过协议或ABC模拟 | Python 2.6+ 通过ABC模块支持 |
使用场景 | 强调"能做什么"的契约 | 强调"是什么"的层级关系 |
@abstractmethod | 可选(如果用ABC实现接口) | 必须修饰抽象方法 |
属性定义 | 通常只定义方法 | 可以定义实例属性、类属性 |
from abc import ABC, abstractmethod
# 接口风格实现(通过抽象类模拟)
class IShape(ABC):
@abstractmethod
def area(self):
pass
@abstractmethod
def perimeter(self):
pass
# 抽象类实现
class Animal(ABC):
def __init__(self, name):
self.name = name
@abstractmethod
def make_sound(self):
pass
def eat(self):
print(f"{self.name} is eating")
# 具体实现
class Dog(Animal):
def make_sound(self):
print("Bark!")
class Circle(IShape):
def __init__(self, radius):
self.radius = radius
def area(self):
return 3.14 * self.radius ** 2
def perimeter(self):
return 2 * 3.14 * self.radius
SOLID & 设计原则
缩写 | 全称 | Python 示例 |
---|---|---|
SRP | 单一职责 | 一个类只干一件「领域」事,如 Downloader 只管下载,解析另开 Parser |
OCP | 开闭原则 | 用策略模式+插件,新增功能不改旧代码 |
LSP | 里氏替换 | 子类不能更严格(不能抛新异常、不能削弱前置条件) |
ISP | 接口隔离 | Protocol 拆分小而专的接口,如 Readable , Writable |
DIP | 依赖反转 | 高层模块依赖抽象(Downloader 接口)而非具体实现(requests ) |
- 用 Protocol 做结构化子类型(无需显式继承即可被视为子类型)。
- 用 TypeVar(bound=...) 表达「上界」,实现泛型逆变/协变。
# OCP示例
from abc import ABC, abstractmethod
class Shape(ABC):
@abstractmethod
def area(self):
pass
class Rectangle(Shape):
def __init__(self, width, height):
self.width = width
self.height = height
def area(self):
return self.width * self.height
class Circle(Shape):
def __init__(self, radius):
self.radius = radius
def area(self):
return 3.14 * self.radius ** 2
def total_area(shapes: list[Shape]):
return sum(shape.area() for shape in shapes)
@dataclass
数据类(dataclass)是Python 3.7+引入的强大特性,它极大地简化了类的定义过程,特别适合主要用来存储数据的类
结合类型提示(Type Hints),可以写出既简洁又类型安全的代码。
基本用法
@dataclass装饰器自动为类生成特殊方法,如__init__
、__repr__
、__eq__
等。
from dataclasses import dataclass
@dataclass
class Point:
x: float
y: float
z: float = 0.0 # 默认值
p = Point(1.5, 2.5)
print(p) # 自动生成__repr__: Point(x=1.5, y=2.5, z=0.0)
主要自动生成的方法
__init__()
: 基于类型提示的构造函数__repr__()
: 友好的字符串表示__eq__()
: 基于字段值的相等性比较__hash__()
: 如果frozen=True则生成
slots优化
slots=True可以显著减少内存使用并提高属性访问速度。
@dataclass(slots=True)
class InventoryItem:
name: str
unit_price: float
quantity: int = 0
default_factory参数
dataclasses模块中field()函数的一个重要参数,它解决了Python类中可变默认值的常见问题,并提供了更灵活的默认值生成方式。
default_factory接受一个可调用对象(通常是一个函数或类),在每次创建实例时调用它来生成默认值
传统的问题:
class RegularClass:
def __init__(self, items=[]): # 危险!可变默认值
self.items = items
a = RegularClass()
a.items.append(1)
b = RegularClass()
print(b.items) # [1] - 这不是我们期望的!
传统解决
class RegularClass:
def __init__(self, items=None):
self.items = items if items is not None else []
# 使用自定义函数
from dataclasses import dataclass, field
@dataclass
class ShoppingCart:
items: list = field(default_factory=list) # 每次创建新列表
from uuid import uuid4
@dataclass
class User:
id: str = field(default_factory=lambda: uuid4().hex)
name: str
# 使用类构造函数
@dataclass
class Department:
name: str
employees: list = field(default_factory=list)
created_at: dict = field(default_factory=dict)
# 复杂数据结构
from typing import DefaultDict
@dataclass
class Inventory:
stock: DefaultDict[str, int] = field(
default_factory=lambda: defaultdict(int)
)
# 嵌套数据结构
@dataclass
class TreeNode:
value: int
children: list = field(default_factory=list)
root = TreeNode(1)
root.children.append(TreeNode(2))
root.children.append(TreeNode(3))
# 配置系统
@dataclass
class AppConfig:
settings: dict = field(default_factory=dict)
plugins: list = field(default_factory=list)
debug: bool = False
# 缓存系统
from datetime import datetime
@dataclass
class Cache:
data: dict = field(default_factory=dict)
last_updated: datetime = field(default_factory=datetime.now)
特性 | default | default_factory |
---|---|---|
赋值时机 | 类定义时 | 实例创建时 |
适用类型 | 不可变对象 | 可变对象或需要动态生成的值 |
共享风险 | 有(如果是可变对象) | 无 |
性能 | 稍高(只计算一次) | 稍低(每次实例化都调用) |
典型用例 | count: int = 0 | items: list = field(default_factory=list) |
语法 | 直接赋值 | 必须通过field() 指定 |
内存效率 | 较高(共享默认值) | 较低(每个实例独立) |
线程安全性 | 不安全(共享可变状态) | 安全(独立实例) |
动态值支持 | 不支持 | 支持(可调用对象) |
不可变数据类
frozen=True创建不可变实例(类似namedtuple)。
@dataclass(frozen=True)
class ImmutablePoint:
x: float
y: float
p = ImmutablePoint(1.0, 2.0)
# p.x = 3.0 # 报错: dataclasses.FrozenInstanceError
字段控制
from dataclasses import field
@dataclass
class C:
x: int
y: int = field(repr=False) # 不在__repr__中显示
z: list = field(default_factory=list) # 避免可变默认值问题
@property
@property 是一个 数据描述符(实现了 __get__/__set__/__delete__
的特殊类)。
被装饰的方法在类级别会被替换成 property 实例;访问属性时触发描述符协议。
property(fget=None, fset=None, fdel=None, doc=None)
参数 | 作用 | 默认值 |
---|---|---|
fget | 读取(getter)函数 | None ⇒ 只写属性 |
fset | 写入(setter)函数 | None ⇒ 只读属性 |
fdel | 删除(deleter)函数 | None ⇒ 不可 del obj.attr |
doc | 该属性的 docstring;也用于 help() 及 IDE 提示 | None |
最小可运行模型(手搓版 property)
class MyProperty:
"""山寨版 @property,帮助理解底层"""
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
self.fget = fget
self.fset = fset
self.fdel = fdel
self.__doc__ = doc
def __get__(self, obj, objtype=None):
if obj is None:
return self
if self.fget is None:
raise AttributeError('unreadable')
return self.fget(obj)
def __set__(self, obj, value):
if self.fset is None:
raise AttributeError('can\'t set')
self.fset(obj, value)
def setter(self, fset):
return type(self)(self.fget, fset, self.fdel, self.__doc__)
class Demo:
def __init__(self, x):
self._x = x
@MyProperty
def x(self):
return self._x
@x.setter
def x(self, value):
self._x = float(value)
d = Demo(3)
print(d.x) # 3.0
d.x = 5
print(d.x) # 5.0
只读属性(最常见)
class Celsius:
def __init__(self, temp):
self._temp = temp
@property
def temperature(self):
"""返回摄氏度,外部无法赋值"""
return self._temp
读写属性(带校验)
class Celsius:
def __init__(self, temp=0):
self.temperature = temp # 走 setter
@property
def temperature(self):
return self._temp
@temperature.setter
def temperature(self, value):
if value < -273.15:
raise ValueError('below absolute zero')
self._temp = float(value)
@temperature.deleter
def temperature(self):
print('delete _temp')
del self._temp
缓存属性(惰性求值 & 只计算一次)
import functools, math
class Circle:
def __init__(self, radius):
self.radius = radius
@property
@functools.cache # 3.9+;老版本用 functools.lru_cache
def area(self):
print('heavy compute...')
return math.pi * self.radius ** 2
c = Circle(2)
print(c.area) # heavy compute... 12.566
print(c.area) # 12.566 不再打印
实际案例
@property 把 User 的 内部状态 包装成 安全、智能、易用的公开接口 ——
外部像访问普通属性一样简单,内部却能做校验、缓存、脱敏,实现真正的封装。
封装一个 “用户 User” 类,把「只读、可读写、校验、缓存、脱敏」一次写完,并给出完整可执行代码 + 单元测试。
字段 | 规则 | 实现方式 |
---|---|---|
uid | 只读,生成后不能改 | 只写 @property |
email | 必须符合正则 | @email.setter 里校验 |
_pwd_hash | 内部存储,对外不可见 | 私有变量 |
password | 写入时哈希,读取时报错 | @password.setter 只写 |
display_name | 为空时回退到邮箱前缀 | getter 里计算 |
age | 0-120 岁 | setter 里校验 |
profile_complete | 缓存属性,是否完整 | @cached_property |
import re
import hashlib
from functools import cached_property
from dataclasses import dataclass
@dataclass(slots=True)
class User:
_email: str
_pwd_hash: str
_age: int | None = None
_display_name: str | None = None
# ---------- uid:只读 ----------
@property
def uid(self) -> str:
"""全局唯一 ID(只读)"""
# 用邮箱哈希做 demo,实际可用 uuid
return hashlib.sha256(self._email.encode()).hexdigest()[:8]
# ---------- email:读写 + 正则校验 ----------
@property
def email(self) -> str:
return self._email
@email.setter
def email(self, value: str) -> None:
if not re.fullmatch(r'^[^@]+@[^@]+\.[^@]+$', value):
raise ValueError('Invalid email')
self._email = value
# ---------- password:仅写 ----------
@property
def password(self):
raise AttributeError('password is write-only')
@password.setter
def password(self, plain: str) -> None:
if len(plain) < 6:
raise ValueError('password too short')
self._pwd_hash = hashlib.sha256(plain.encode()).hexdigest()
def check_password(self, plain: str) -> bool:
return self._pwd_hash == hashlib.sha256(plain.encode()).hexdigest()
# ---------- display_name:计算属性 ----------
@property
def display_name(self) -> str:
return self._display_name or self.email.split('@')[0]
@display_name.setter
def display_name(self, value: str) -> None:
self._display_name = value.strip() or None
# ---------- age:范围校验 ----------
@property
def age(self) -> int | None:
return self._age
@age.setter
def age(self, value: int | None) -> None:
if value is not None and not 0 <= value <= 120:
raise ValueError('age must be 0-120')
self._age = value
# ---------- profile_complete:缓存 ----------
@cached_property
def profile_complete(self) -> bool:
print('expensive check...')
return self.age is not None and self._display_name is not None
# ----------------- 使用演示 -----------------
if __name__ == '__main__':
u = User('alice@example.com', 'legacy')
u.password = '123456'
u.age = 25
u.display_name = 'Alice'
print(u.uid) # 只读
print(u.display_name) # Alice
print(u.profile_complete) # expensive check... True
print(u.profile_complete) # 不再打印,已缓存
实战
爬虫框架
带事件系统、插件机制、策略路由的复杂爬虫框架
- 支持不同下载策略(requests、httpx、aiohttp)
- 支持不同解析策略(lxml、bs4、re)
- 支持插件:重试、限速、日志
- 事件总线:下载前、解析后、错误时
classDiagram
SpiderFramework ..> EventBus
SpiderFramework ..> Downloader
SpiderFramework ..> Parser
SpiderFramework ..> Plugin
Downloader <|-- RequestsDownloader
Downloader <|-- HttpxDownloader
Parser <|-- LxmlParser
Parser <|-- BS4Parser
Plugin <|-- RetryPlugin
Plugin <|-- RateLimitPlugin
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Type
import asyncio, time, httpx, requests, re, random
# ---------- 事件系统 ----------
class EventBus:
def __init__(self):
self.handlers = {}
def on(self, event: str, handler):
self.handlers.setdefault(event, []).append(handler)
def emit(self, event: str, *args, **kw):
for h in self.handlers.get(event, []):
h(*args, **kw)
# ---------- 插件 ----------
class Plugin(ABC):
def __init__(self, bus: EventBus):
self.bus = bus
self.register()
@abstractmethod
def register(self): ...
class RetryPlugin(Plugin):
def register(self):
self.bus.on('download_error', self.retry)
def retry(self, url, exception, retry=3, delay=1):
if retry:
time.sleep(delay)
self.bus.emit('download', url, retry - 1)
class RateLimitPlugin(Plugin):
def __init__(self, bus, rps=5):
super().__init__(bus)
self.interval = 1 / rps
self.last = 0
def register(self):
self.bus.on('before_download', self.throttle)
def throttle(self, url):
while (gap := time.time() - self.last) < self.interval:
time.sleep(self.interval - gap)
self.last = time.time()
# ---------- 下载器 ----------
class Downloader(ABC):
@abstractmethod
def fetch(self, url): ...
class RequestsDownloader(Downloader):
def fetch(self, url):
return requests.get(url, timeout=5).text
class HttpxDownloader(Downloader):
def fetch(self, url):
return httpx.get(url, timeout=5).text
class AiohttpDownloader(Downloader):
async def fetch(self, url):
import aiohttp
async with aiohttp.ClientSession() as s:
async with s.get(url) as r:
return await r.text()
# ---------- 解析器 ----------
class Parser(ABC):
@abstractmethod
def parse(self, html) -> List[str]: ...
class RegexTitleParser(Parser):
def parse(self, html):
return re.findall('<title>(.*?)</title>', html, flags=re.I)
class BS4TitleParser(Parser):
def parse(self, html):
from bs4 import BeautifulSoup
return [BeautifulSoup(html, 'lxml').title.string]
# ---------- 框架 ----------
class SpiderFramework:
def __init__(self, downloader: Type[Downloader], parser: Type[Parser],
plugins: List[Type[Plugin]] = None):
self.bus = EventBus()
for P in plugins or []:
P(self.bus)
self.downloader = downloader()
self.parser = parser()
def run(self, urls: List[str]):
for url in urls:
self.bus.emit('before_download', url)
try:
html = self.downloader.fetch(url)
titles = self.parser.parse(html)
self.bus.emit('after_parse', url, titles)
except Exception as e:
self.bus.emit('download_error', url, e)
# ---------- 运行 ----------
if __name__ == '__main__':
urls = ['https://example.com', 'https://httpbin.org/html']
spider = SpiderFramework(
downloader=HttpxDownloader,
parser=BS4TitleParser,
plugins=[RetryPlugin, RateLimitPlugin]
)
spider.run(urls)
时序图:一次完整的 URL 抓取流程
sequenceDiagram
autonumber
participant Main as 主程序
participant Spider as SpiderFramework
participant Bus as EventBus
participant Rate as RateLimitPlugin
participant Retry as RetryPlugin
participant Downloader as HttpxDownloader
participant Parser as BS4TitleParser
loop 遍历 URL 列表
Main->>Spider: run([url1, url2])
Spider->>Bus: emit("before_download", url)
Bus-->>Rate: throttle(url)
Rate-->>Rate: sleep if needed
Spider->>Downloader: fetch(url)
alt 成功
Downloader-->>Spider: html
Spider->>Parser: parse(html)
Parser-->>Spider: titles
Spider->>Bus: emit("after_parse", url, titles)
else 异常
Downloader--xSpider: Exception
Spider->>Bus: emit("download_error", url, e)
Bus-->>Retry: retry(url, e, retry=3)
Retry-->>Bus: emit("download", url, retry-1)
Bus-->>Spider: 重新进入循环
end
end
活动图:事件总线内部的分叉与合并
flowchart TD
A[Spider.run 启动] --> B{遍历每个 URL}
B --> C[emit before_download]
C --> D[RateLimitPlugin.throttle]
D --> E[Downloader.fetch]
E --> F{成功?}
F -->|yes| G[Parser.parse]
G --> H[emit after_parse]
F -->|no| I[emit download_error]
I --> J{Retry>0?}
J -->|yes| K[RetryPlugin.retry]
K --> C
J -->|no| L[结束该 URL]
H --> B
插件注册顺序可视化(静态图)
classDiagram
class EventBus {
+on(event, handler)
+emit(event, ...args)
}
class Plugin {
<<abstract>>
+register()
}
Plugin <|-- RetryPlugin
Plugin <|-- RateLimitPlugin
EventBus o-- Plugin : uses
电商订单管理系统
涵盖了商品管理、购物车、订单处理等常见业务场景,并使用了继承、多态、封装等核心 OOP 概念。
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Dict, Optional
from abc import ABC, abstractmethod
import uuid
# -------------------- 基础模型类 --------------------
class Entity(ABC):
"""抽象基类,所有实体类的父类"""
def __init__(self, id: str = None):
self.id = id or str(uuid.uuid4())
@abstractmethod
def display(self) -> str:
"""显示对象信息的抽象方法"""
pass
# -------------------- 商品相关类 --------------------
@dataclass
class Product(Entity):
"""商品类"""
name: str
price: float
description: str = ""
stock: int = 0
category: str = "uncategorized"
def display(self) -> str:
return f"{self.name} - ${self.price:.2f} ({self.stock} in stock)"
def reduce_stock(self, quantity: int) -> None:
"""减少库存"""
if quantity > self.stock:
raise ValueError("Insufficient stock")
self.stock -= quantity
# -------------------- 用户相关类 --------------------
@dataclass
class User(Entity):
"""用户基类"""
name: str
email: str
address: str = ""
def display(self) -> str:
return f"{self.name} <{self.email}>"
@dataclass
class Customer(User):
"""顾客类"""
loyalty_points: int = 0
cart: 'ShoppingCart' = field(default_factory=lambda: ShoppingCart())
def add_loyalty_points(self, points: int) -> None:
"""增加忠诚度积分"""
self.loyalty_points += points
@dataclass
class Admin(User):
"""管理员类"""
access_level: str = "basic"
# -------------------- 购物车和订单类 --------------------
@dataclass
class CartItem:
"""购物车项"""
product: Product
quantity: int = 1
@property
def subtotal(self) -> float:
return self.product.price * self.quantity
@dataclass
class ShoppingCart:
"""购物车"""
items: List[CartItem] = field(default_factory=list)
def add_item(self, product: Product, quantity: int = 1) -> None:
"""添加商品到购物车"""
for item in self.items:
if item.product.id == product.id:
item.quantity += quantity
return
self.items.append(CartItem(product, quantity))
def remove_item(self, product_id: str) -> None:
"""从购物车移除商品"""
self.items = [item for item in self.items if item.product.id != product_id]
@property
def total(self) -> float:
"""计算购物车总价"""
return sum(item.subtotal for item in self.items)
def clear(self) -> None:
"""清空购物车"""
self.items = []
@dataclass
class Order(Entity):
"""订单类"""
customer: Customer
items: List[CartItem]
order_date: datetime = field(default_factory=datetime.now)
status: str = "pending"
shipping_address: str = ""
def display(self) -> str:
return f"Order #{self.id[:8]} - {self.status} - Total: ${self.total:.2f}"
@property
def total(self) -> float:
return sum(item.subtotal for item in self.items)
def apply_discount(self, discount_percent: float) -> None:
"""应用折扣"""
if discount_percent < 0 or discount_percent > 100:
raise ValueError("Discount must be between 0 and 100")
# 在实际应用中,这里会有更复杂的折扣逻辑
for item in self.items:
item.product.price *= (1 - discount_percent / 100)
# -------------------- 支付相关类 --------------------
class PaymentMethod(ABC):
"""支付方式抽象基类"""
@abstractmethod
def process_payment(self, amount: float) -> bool:
pass
class CreditCardPayment(PaymentMethod):
"""信用卡支付"""
def __init__(self, card_number: str, expiry_date: str, cvv: str):
self.card_number = card_number
self.expiry_date = expiry_date
self.cvv = cvv
def process_payment(self, amount: float) -> bool:
print(f"Processing credit card payment of ${amount:.2f}")
# 模拟支付处理
return True
class PayPalPayment(PaymentMethod):
"""PayPal支付"""
def __init__(self, email: str):
self.email = email
def process_payment(self, amount: float) -> bool:
print(f"Processing PayPal payment of ${amount:.2f} from {self.email}")
# 模拟支付处理
return True
# -------------------- 订单服务类 --------------------
class OrderService:
"""订单服务类,处理订单相关业务逻辑"""
def __init__(self):
self.orders: Dict[str, Order] = {}
def create_order(self, customer: Customer) -> Order:
"""从购物车创建订单"""
if not customer.cart.items:
raise ValueError("Cannot create order from empty cart")
order = Order(
customer=customer,
items=customer.cart.items.copy(),
shipping_address=customer.address
)
self.orders[order.id] = order
customer.cart.clear()
return order
def process_order(self, order: Order, payment_method: PaymentMethod) -> bool:
"""处理订单支付"""
if order.status != "pending":
raise ValueError("Order is already processed")
if payment_method.process_payment(order.total):
order.status = "completed"
# 减少库存
for item in order.items:
item.product.reduce_stock(item.quantity)
# 增加忠诚度积分 (每消费$1得1积分)
order.customer.add_loyalty_points(int(order.total))
return True
return False
def get_order(self, order_id: str) -> Optional[Order]:
"""获取订单"""
return self.orders.get(order_id)
# -------------------- 示例使用 --------------------
def main():
# 创建一些商品
laptop = Product(
name="MacBook Pro",
price=1999.99,
description="Apple M1 Pro, 16GB RAM, 512GB SSD",
stock=10,
category="Electronics"
)
phone = Product(
name="iPhone 13",
price=799.99,
description="A15 Bionic, 128GB",
stock=15,
category="Electronics"
)
book = Product(
name="Python Crash Course",
price=39.99,
description="Learn Python programming",
stock=50,
category="Books"
)
# 创建顾客
customer = Customer(
name="John Doe",
email="john@example.com",
address="123 Main St, Anytown, USA"
)
# 顾客添加商品到购物车
customer.cart.add_item(laptop)
customer.cart.add_item(phone, 2)
customer.cart.add_item(book)
print("Shopping Cart Contents:")
for item in customer.cart.items:
print(f"- {item.product.name} x{item.quantity}: ${item.subtotal:.2f}")
print(f"Total: ${customer.cart.total:.2f}")
# 创建订单
order_service = OrderService()
order = order_service.create_order(customer)
print(f"\nCreated {order.display()}")
# 处理支付
payment_method = CreditCardPayment(
card_number="4111111111111111",
expiry_date="12/25",
cvv="123"
)
if order_service.process_order(order, payment_method):
print("Payment processed successfully!")
print(f"Order status: {order.status}")
print(f"Customer loyalty points: {customer.loyalty_points}")
else:
print("Payment failed")
# 检查库存
print("\nUpdated Stock Levels:")
print(f"- {laptop.name}: {laptop.stock}")
print(f"- {phone.name}: {phone.stock}")
print(f"- {book.name}: {book.stock}")
if __name__ == "__main__":
main()