你有没有遇到过这种场景:让Agent做一件复杂的事,它做到一半就卡住了,或者做得乱七八糟?
比如:"帮我重构这个服务,提升性能"。这个任务够复杂了吧?它涉及代码分析、性能测试、方案设计、实现、部署、监控……一个Agent从头做到尾,几乎不可能做好。
Planning模式就是解决这个问题的:先把大任务拆解成小任务,再逐个执行。
听起来很简单,但90%的人都用错了。
最常见的错误是:Planner只是简单地把任务切成几块,没有考虑依赖关系、失败恢复、状态管理。结果就是:任务卡在第3步,你不知道第1、2步做了什么,也不知道从哪里恢复。
今天这篇,我不只告诉你Planning是什么,还要告诉你:
如何设计一个真正可用的Planner。
Planning的本质:分层决策
先看一个错误的例子。
错误做法:简单拆解,不考虑依赖
用户任务:帮我重构这个服务,提升性能 Planner Agent: 1. 优化数据库查询 2. 添加缓存 3. 优化算法 4. 部署 5. 监控 Worker Agent 1:执行优化数据库查询 Worker Agent 2:执行添加缓存 Worker Agent 3:执行优化算法 Worker Agent 4:执行部署 Worker Agent 5:执行监控 结果: - Worker Agent 1优化了查询,但没测试,不确定有没有问题 - Worker Agent 2添加了缓存,但Worker Agent 1的改动让缓存key设计不合理 - Worker Agent 3优化了算法,但和Worker Agent 2的缓存冲突了 - Worker Agent 4部署了,但前面的改动没验证,可能引入了bug - Worker Agent 5监控了,但不知道要监控什么指标
这个实现有什么问题?
第一,缺少依赖分析。步骤1、2、3互相影响,不应该并行执行。
第二,缺少验证环节。每一步做完后,需要验证是否达到了预期效果。
第三,缺少状态管理。如果步骤2失败了,你不知道步骤1做了什么,也不知道如何回滚。
正确做法:分层规划,考虑依赖和恢复
用户任务:帮我重构这个服务,提升性能 Planner Agent分析: 这是一个复杂任务,需要分层规划: 第一层:分析阶段(必须串行) - 步骤1:分析当前代码,找出性能瓶颈 - 步骤2:验证瓶颈,制定优化目标 - 步骤3:根据瓶颈类型,选择优化策略 第二层:优化阶段(部分并行,但有依赖) - 步骤4:优化数据库查询(依赖步骤3的决策) - 步骤5:添加缓存(依赖步骤4,需要知道查询逻辑) - 步骤6:优化算法(依赖步骤3,和步骤4、5并行) 第三层:验证阶段(必须串行) - 步骤7:性能测试,验证优化效果 - 步骤8:如果效果不达标,回到第二层调整 第四层:部署阶段(必须串行) - 步骤9:部署到测试环境 - 步骤10:监控指标,确认无异常 - 步骤11:部署到生产环境 第五层:监控阶段(持续执行) - 步骤12:持续监控,收集数据 - 步骤13:根据数据,决定是否需要进一步优化
看出来了吗?Planning的核心不是"切任务",而是分层决策。
分层决策有三个好处:
第一,降低复杂度。每层只关注自己的职责,不会互相干扰。
第二,提高可恢复性。如果某个步骤失败了,你知道它属于哪一层,也知道如何从这一层恢复。
第三,便于调试。你可以逐层验证,而不是一下子检查整个任务。
Planning的完整工作流程
Planning是一个五步流程:
分析→拆解→依赖分析→分配→执行→监控→恢复。
▪ Step 1:任务分析

Planner需要回答三个问题:
示例:
用户任务:帮我重构这个服务,提升性能 Planner分析: - 目标:提升服务性能 - 约束条件: * 不能破坏现有功能 * 需要可回滚 * 改动要可测试 - 成功标准: * 响应时间降低50% * 错误率不增加 * 资源占用不增加
▪ Step 2:任务拆解
Planner需要把大任务拆解成小任务。拆解的粒度要合适:
拆解原则:
示例:
任务拆解: 任务1:分析性能瓶颈 - 子任务1.1:分析代码逻辑 - 子任务1.2:分析数据库查询 - 子任务1.3:分析系统资源 任务2:设计优化方案 - 子任务2.1:数据库优化方案 - 子任务2.2:缓存策略设计 - 子任务2.3:算法优化方案 任务3:实现优化 - 子任务3.1:实现数据库优化 - 子任务3.2:实现缓存 - 子任务3.3:实现算法优化 任务4:验证效果 - 子任务4.1:性能测试 - 子任务4.2:回归测试 任务5:部署上线 - 子任务5.1:部署到测试环境 - 子任务5.2:部署到生产环境
▪ Step 3:依赖分析
这是Planning模式最关键的一步。你需要分析任务之间的依赖关系:
示例:
依赖关系分析: 强依赖链: 任务1(分析)→ 任务2(设计)→ 任务3(实现)→ 任务4(验证)→ 任务5(部署) 弱依赖: 子任务3.1(数据库优化)→ 子任务3.2(缓存) 子任务3.1(数据库优化) → 子任务3.3(算法优化) 无依赖: 子任务1.1、1.2、1.3可以并行 子任务2.1、2.2、2.3可以并行
▪ Step 4:任务分配
Planner需要把任务分配给Worker Agent。
分配时考虑:
示例:
Worker Agent配置: - CodeAnalyzer Worker:擅长代码分析 - DBWorker:擅长数据库操作 - CacheWorker:擅长缓存设计 - AlgorithmWorker:擅长算法优化 - TestWorker:擅长测试 - DeployWorker:擅长部署 任务分配: 任务1 → CodeAnalyzer Worker 任务2.1 → DBWorker 任务2.2 → CacheWorker 任务2.3 → AlgorithmWorker 任务3.1 → DBWorker 任务3.2 → CacheWorker 任务3.3 → AlgorithmWorker 任务4 → TestWorker 任务5 → DeployWorker
▪ Step 5:执行与监控

Worker Agent执行任务,Planner监控执行状态。
监控内容:
示例:
执行监控: 时间 10:00 - 任务1:执行中 - 任务2:待执行(依赖任务1) - 任务3:待执行(依赖任务2) - 任务4:待执行(依赖任务3) - 任务5:待执行(依赖任务4) 时间 10:15 - 任务1:已完成(耗时15分钟) - 任务2:执行中 - 任务3:待执行(依赖任务2) - 任务4:待执行(依赖任务3) - 任务5:待执行(依赖任务4) 时间 10:45 - 任务1:已完成 - 任务2:已完成(耗时30分钟) - 任务3.1:执行中 - 任务3.2:待执行(依赖3.1) - 任务3.3:执行中(和3.1并行) - 任务4:待执行(依赖3) - 任务5:待执行(依赖任务4)
▪ Step 6:失败恢复
这是Planning模式最重要的功能。如果某个任务失败了,Planner需要决定如何恢复:
示例:
失败恢复场景: 场景1:任务3.1(数据库优化)失败 - 判断:这是关键任务,必须完成 - 操作:重试3次,如果还是失败,回滚任务2.1的改动,从任务2.1重新开始 场景2:任务3.2(缓存)失败 - 判断:这是非关键任务,有降级方案 - 操作:跳过任务3.2,继续执行任务3.3,最后提醒用户缓存未添加 场景3:任务4(验证)失败 - 判断:这是关键任务,必须通过 - 操作:终止整个流程,分析失败原因,决定是否重新从任务3开始
任务树结构
Planning模式的核心数据结构是任务树。任务树是一个有向无环图(DAG),每个节点是一个任务,边表示依赖关系。
▪ 任务树的数据结构
from dataclasses import dataclass from typing import List, Optional from enum import Enum class TaskStatus(Enum): PENDING = "pending" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" SKIPPED = "skipped" class TaskPriority(Enum): HIGH = "high" MEDIUM = "medium" LOW = "low" @dataclass class Task: id: str name: str description: str status: TaskStatus = TaskStatus.PENDING priority: TaskPriority = TaskPriority.MEDIUM dependencies: List[str] = None worker_type: str = None result: Optional[dict] = None error: Optional[str] = None start_time: Optional[float] = None end_time: Optional[float] = None retry_count: int = 0 max_retries: int = 3 def __post_init__(self): if self.dependencies is None: self.dependencies = [] @dataclass class TaskNode: task: Task children: List['TaskNode'] = None parent: Optional['TaskNode'] = None def __post_init__(self): if self.children is None: self.children = [] class TaskTree: def __init__(self): self.tasks: dict[str, Task] = {} self.root: Optional[TaskNode] = None def add_task(self, task: Task): self.tasks[task.id] = task def get_task(self, task_id: str) -> Optional[Task]: return self.tasks.get(task_id) def get_ready_tasks(self) -> List[Task]: """获取所有依赖已满足的任务""" ready_tasks = [] for task in self.tasks.values(): if task.status == TaskStatus.PENDING: dependencies_met = all( self.get_task(dep_id).status == TaskStatus.COMPLETED for dep_id in task.dependencies ) if dependencies_met: ready_tasks.append(task) return ready_tasks def get_failed_tasks(self) -> List[Task]: """获取所有失败的任务""" return [task for task in self.tasks.values() if task.status == TaskStatus.FAILED] def get_completed_tasks(self) -> List[Task]: """获取所有已完成的任务""" return [task for task in self.tasks.values() if task.status == TaskStatus.COMPLETED] def is_complete(self) -> bool: """检查所有任务是否都已完成(包括跳过)""" return all( task.status in [TaskStatus.COMPLETED, TaskStatus.SKIPPED] for task in self.tasks.values() ) def has_failed(self) -> bool: """检查是否有任务失败""" return any(task.status == TaskStatus.FAILED for task in self.tasks.values())
▪ 使用任务树的示例
# 创建任务树 tree = TaskTree() # 添加任务 tree.add_task(Task( id="1", name="分析性能瓶颈", description="分析代码、数据库、系统资源,找出性能瓶颈", worker_type="code_analyzer", priority=TaskPriority.HIGH )) tree.add_task(Task( id="2.1", name="设计数据库优化方案", description="根据瓶颈分析结果,设计数据库优化方案", dependencies=["1"], worker_type="db_worker", priority=TaskPriority.HIGH )) tree.add_task(Task( id="2.2", name="设计缓存策略", description="设计缓存key、失效策略、更新策略", dependencies=["1"], worker_type="cache_worker", priority=TaskPriority.MEDIUM )) tree.add_task(Task( id="3.1", name="实现数据库优化", description="实现数据库查询优化、索引优化", dependencies=["2.1"], worker_type="db_worker", priority=TaskPriority.HIGH )) tree.add_task(Task( id="3.2", name="实现缓存", description="实现Redis缓存", dependencies=["3.1"], worker_type="cache_worker", priority=TaskPriority.MEDIUM )) tree.add_task(Task( id="4", name="验证效果", description="进行性能测试和回归测试", dependencies=["3.1", "3.2"], worker_type="test_worker", priority=TaskPriority.HIGH )) # 获取可以执行的任务 ready_tasks = tree.get_ready_tasks() print(f"可以执行的任务:{[t.name for t in ready_tasks]}") # 输出:可以执行的任务:['分析性能瓶颈'] # 模拟任务1完成 task1 = tree.get_task("1") task1.status = TaskStatus.COMPLETED task1.result = {"bottlenecks": ["database", "algorithm"]} # 再次获取可以执行的任务 ready_tasks = tree.get_ready_tasks() print(f"可以执行的任务:{[t.name for t in ready_tasks]}") # 输出:可以执行的任务:['设计数据库优化方案', '设计缓存策略']
Planner与Worker的协作
Planning模式需要Planner和Worker之间的清晰协作机制。协作的核心是通信协议和状态同步。
▪ 通信协议
Planner和Worker之间的通信应该包括以下内容:
消息格式示例:
from dataclasses import dataclass from typing import Any from enum import Enum class MessageType(Enum): ASSIGN_TASK = "assign_task" TASK_UPDATE = "task_update" TASK_RESULT = "task_result" TASK_FAILURE = "task_failure" RECOVERY_COMMAND = "recovery_command" @dataclass class Message: type: MessageType task_id: str sender: str receiver: str payload: dict timestamp: float # Planner发送任务分配 assign_message = Message( type=MessageType.ASSIGN_TASK, task_id="3.1", sender="planner", receiver="db_worker", payload={ "task_name": "实现数据库优化", "description": "实现数据库查询优化、索引优化", "input_data": { "bottlenecks": ["database", "algorithm"], "optimization_plan": "..." } }, timestamp=time.time() ) # Worker发送状态更新 update_message = Message( type=MessageType.TASK_UPDATE, task_id="3.1", sender="db_worker", receiver="planner", payload={ "status": "running", "progress": 0.5, "current_step": "添加索引" }, timestamp=time.time() ) # Worker发送结果报告 result_message = Message( type=MessageType.TASK_RESULT, task_id="3.1", sender="db_worker", receiver="planner", payload={ "status": "completed", "result": { "optimizations": ["添加索引", "优化查询"], "performance_improvement": "40%" } }, timestamp=time.time() ) # Worker发送失败通知 failure_message = Message( type=MessageType.TASK_FAILURE, task_id="3.2", sender="cache_worker", receiver="planner", payload={ "error": "Redis连接失败", "error_type": "ConnectionError", "retry_count": 3, "can_recover": True }, timestamp=time.time() ) # Planner发送恢复指令 recovery_message = Message( type=MessageType.RECOVERY_COMMAND, task_id="3.2", sender="planner", receiver="cache_worker", payload={ "command": "skip", "reason": "缓存是可选功能,跳过不影响核心优化", "next_task": "4" }, timestamp=time.time() )
▪ 状态同步
Planner需要维护整个任务树的状态,Worker需要维护自己执行的任务状态。
状态同步的时机:
状态同步的实现:
class Planner: def __init__(self, task_tree: TaskTree): self.task_tree = task_tree self.message_queue = [] self.workers = {} def register_worker(self, worker_id: str, worker_type: str): """注册Worker""" self.workers[worker_id] = { "type": worker_type, "status": "idle", "current_task": None } def assign_task(self, task: Task): """分配任务给合适的Worker""" # 找到合适的Worker suitable_worker = self._find_suitable_worker(task.worker_type) if not suitable_worker: raise ValueError(f"没有找到类型为 {task.worker_type} 的Worker") # 更新Worker状态 self.workers[suitable_worker]["status"] = "busy" self.workers[suitable_worker]["current_task"] = task.id # 更新任务状态 task.status = TaskStatus.RUNNING task.start_time = time.time() # 发送任务分配消息 message = Message( type=MessageType.ASSIGN_TASK, task_id=task.id, sender="planner", receiver=suitable_worker, payload={ "task_name": task.name, "description": task.description, "input_data": self._get_task_input(task) }, timestamp=time.time() ) self._send_message(message) def handle_message(self, message: Message): """处理Worker发来的消息""" if message.type == MessageType.TASK_UPDATE: self._handle_task_update(message) elif message.type == MessageType.TASK_RESULT: self._handle_task_result(message) elif message.type == MessageType.TASK_FAILURE: self._handle_task_failure(message) def _handle_task_update(self, message: Message): """处理任务更新""" task = self.task_tree.get_task(message.task_id) if task: # 记录进度 task.progress = message.payload.get("progress", 0) task.current_step = message.payload.get("current_step", "") # 检查是否有新任务可以执行 ready_tasks = self.task_tree.get_ready_tasks() for ready_task in ready_tasks: self.assign_task(ready_task) def _handle_task_result(self, message: Message): """处理任务结果""" task = self.task_tree.get_task(message.task_id) if task: task.status = TaskStatus.COMPLETED task.end_time = time.time() task.result = message.payload.get("result", {}) # 释放Worker worker_id = message.sender self.workers[worker_id]["status"] = "idle" self.workers[worker_id]["current_task"] = None # 检查是否有新任务可以执行 ready_tasks = self.task_tree.get_ready_tasks() for ready_task in ready_tasks: self.assign_task(ready_task) def _handle_task_failure(self, message: Message): """处理任务失败""" task = self.task_tree.get_task(message.task_id) if task: task.status = TaskStatus.FAILED task.end_time = time.time() task.error = message.payload.get("error", "") # 决定恢复策略 recovery_strategy = self._decide_recovery(task, message.payload) if recovery_strategy == "retry": self._retry_task(task) elif recovery_strategy == "skip": self._skip_task(task) elif recovery_strategy == "rollback": self._rollback_task(task) elif recovery_strategy == "terminate": self._terminate_workflow() def _decide_recovery(self, task: Task, failure_info: dict) -> str: """决定恢复策略""" # 如果任务可以重试,且未超过最大重试次数 if failure_info.get("can_recover", False) and task.retry_count < task.max_retries: return "retry" # 如果是低优先级任务,可以跳过 if task.priority == TaskPriority.LOW: return "skip" # 如果是高优先级任务,且无法恢复,需要回滚 if task.priority == TaskPriority.HIGH: return "rollback" # 默认终止 return "terminate"
状态管理与恢复
状态管理是Planning模式最复杂的部分。你需要考虑:
▪ 状态持久化
import json import os from typing import Optional class StateManager: def __init__(self, state_file: str): self.state_file = state_file self.state = self._load_state() def _load_state(self) -> dict: """从文件加载状态""" if os.path.exists(self.state_file): with open(self.state_file, 'r', encoding='utf-8') as f: return json.load(f) return { "tasks": {}, "workers": {}, "workflow_status": "idle", "created_at": time.time() } def _save_state(self): """保存状态到文件""" with open(self.state_file, 'w', encoding='utf-8') as f: json.dump(self.state, f, indent=2, ensure_ascii=False) def update_task_state(self, task_id: str, task_state: dict): """更新任务状态""" self.state["tasks"][task_id] = task_state self._save_state() def get_task_state(self, task_id: str) -> Optional[dict]: """获取任务状态""" return self.state["tasks"].get(task_id) def update_workflow_status(self, status: str): """更新工作流状态""" self.state["workflow_status"] = status self._save_state() def get_workflow_status(self) -> str: """获取工作流状态""" return self.state.get("workflow_status", "idle") def reset_state(self): """重置状态""" self.state = { "tasks": {}, "workers": {}, "workflow_status": "idle", "created_at": time.time() } self._save_state()
▪ 状态回滚
class RollbackManager: def __init__(self, state_manager: StateManager): self.state_manager = state_manager self.checkpoints = {} def create_checkpoint(self, task_id: str, state: dict): """创建检查点""" self.checkpoints[task_id] = { "state": state.copy(), "timestamp": time.time() } def rollback_to_checkpoint(self, task_id: str): """回滚到检查点""" if task_id not in self.checkpoints: raise ValueError(f"任务 {task_id} 没有检查点") checkpoint = self.checkpoints[task_id] # 恢复任务状态 self.state_manager.update_task_state(task_id, checkpoint["state"]) # 删除检查点 del self.checkpoints[task_id] def rollback_dependent_tasks(self, task_id: str, task_tree: TaskTree): """回滚依赖此任务的所有任务""" task = task_tree.get_task(task_id) if not task: return # 找出所有依赖此任务的任务 dependent_tasks = [] for t in task_tree.tasks.values(): if task_id in t.dependencies: dependent_tasks.append(t) # 递归回滚 for dep_task in dependent_tasks: # 重置任务状态 dep_task.status = TaskStatus.PENDING dep_task.result = None dep_task.error = None dep_task.start_time = None dep_task.end_time = None # 继续回滚依赖此任务的任务 self.rollback_dependent_tasks(dep_task.id, task_tree)
▪ 状态恢复
class RecoveryManager: def __init__(self, state_manager: StateManager, task_tree: TaskTree): self.state_manager = state_manager self.task_tree = task_tree def recover_workflow(self) -> dict: """恢复工作流""" workflow_status = self.state_manager.get_workflow_status() if workflow_status == "idle": return {"status": "idle", "message": "工作流未开始"} if workflow_status == "completed": return {"status": "completed", "message": "工作流已完成"} if workflow_status == "failed": return self._recover_failed_workflow() if workflow_status == "running": return self._recover_running_workflow() return {"status": "unknown", "message": "未知的工作流状态"} def _recover_failed_workflow(self) -> dict: """恢复失败的工作流""" failed_tasks = self.task_tree.get_failed_tasks() if not failed_tasks: return {"status": "completed", "message": "没有失败的任务"} # 找出需要重试的任务 retry_tasks = [] for task in failed_tasks: task_state = self.state_manager.get_task_state(task.id) if task_state and task_state.get("retry_count", 0) < task.max_retries: retry_tasks.append(task) if not retry_tasks: return {"status": "cannot_recover", "message": "没有可以重试的任务"} # 重置任务状态 for task in retry_tasks: task.status = TaskStatus.PENDING task.error = None task.retry_count += 1 return { "status": "recovered", "message": f"已恢复 {len(retry_tasks)} 个任务", "retry_tasks": [t.id for t in retry_tasks] } def _recover_running_workflow(self) -> dict: """恢复运行中的工作流""" running_tasks = [ task for task in self.task_tree.tasks.values() if task.status == TaskStatus.RUNNING ] if not running_tasks: return {"status": "idle", "message": "没有运行中的任务"} # 重置运行中的任务为待执行 for task in running_tasks: task.status = TaskStatus.PENDING task.start_time = None return { "status": "recovered", "message": f"已重置 {len(running_tasks)} 个运行中的任务", "reset_tasks": [t.id for t in running_tasks] }
实战:实现一个完整的Planner
好了,理论讲够了。现在我们来手写一个完整的Planner系统。
▪ 完整代码
import time import json import os from dataclasses import dataclass, asdict from typing import List, Optional, Dict, Callable from enum import Enum from concurrent.futures import ThreadPoolExecutor, Future # ==================== 数据结构 ==================== class TaskStatus(Enum): PENDING = "pending" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" SKIPPED = "skipped" class TaskPriority(Enum): HIGH = "high" MEDIUM = "medium" LOW = "low" class MessageType(Enum): ASSIGN_TASK = "assign_task" TASK_UPDATE = "task_update" TASK_RESULT = "task_result" TASK_FAILURE = "task_failure" RECOVERY_COMMAND = "recovery_command" @dataclass class Task: id: str name: str description: str status: TaskStatus = TaskStatus.PENDING priority: TaskPriority = TaskPriority.MEDIUM dependencies: List[str] = None worker_type: str = None result: Optional[dict] = None error: Optional[str] = None progress: float = 0.0 current_step: str = "" start_time: Optional[float] = None end_time: Optional[float] = None retry_count: int = 0 max_retries: int = 3 input_data: dict = None def __post_init__(self): if self.dependencies is None: self.dependencies = [] if self.input_data is None: self.input_data = {} def to_dict(self): return asdict(self) @classmethod def from_dict(cls, data: dict): data = data.copy() data['status'] = TaskStatus(data['status']) data['priority'] = TaskPriority(data['priority']) return cls(**data) @dataclass class Message: type: MessageType task_id: str sender: str receiver: str payload: dict timestamp: float def to_dict(self): data = asdict(self) data['type'] = self.type.value return data # ==================== 任务树 ==================== class TaskTree: def __init__(self): self.tasks: Dict[str, Task] = {} def add_task(self, task: Task): self.tasks[task.id] = task def get_task(self, task_id: str) -> Optional[Task]: return self.tasks.get(task_id) def get_ready_tasks(self) -> List[Task]: """获取所有依赖已满足的任务""" ready_tasks = [] for task in self.tasks.values(): if task.status == TaskStatus.PENDING: dependencies_met = all( self.get_task(dep_id).status == TaskStatus.COMPLETED for dep_id in task.dependencies if dep_id in self.tasks ) if dependencies_met: ready_tasks.append(task) return ready_tasks def get_failed_tasks(self) -> List[Task]: """获取所有失败的任务""" return [task for task in self.tasks.values() if task.status == TaskStatus.FAILED] def get_completed_tasks(self) -> List[Task]: """获取所有已完成的任务""" return [task for task in self.tasks.values() if task.status == TaskStatus.COMPLETED] def is_complete(self) -> bool: """检查所有任务是否都已完成(包括跳过)""" return all( task.status in [TaskStatus.COMPLETED, TaskStatus.SKIPPED] for task in self.tasks.values() ) def has_failed(self) -> bool: """检查是否有任务失败""" return any(task.status == TaskStatus.FAILED for task in self.tasks.values()) # ==================== 状态管理 ==================== class StateManager: def __init__(self, state_file: str): self.state_file = state_file self.state = self._load_state() def _load_state(self) -> dict: """从文件加载状态""" if os.path.exists(self.state_file): with open(self.state_file, 'r', encoding='utf-8') as f: return json.load(f) return { "tasks": {}, "workflow_status": "idle", "created_at": time.time() } def _save_state(self): """保存状态到文件""" with open(self.state_file, 'w', encoding='utf-8') as f: json.dump(self.state, f, indent=2, ensure_ascii=False) def update_task_state(self, task_id: str, task: Task): """更新任务状态""" self.state["tasks"][task_id] = task.to_dict() self._save_state() def get_task_state(self, task_id: str) -> Optional[dict]: """获取任务状态""" return self.state["tasks"].get(task_id) def update_workflow_status(self, status: str): """更新工作流状态""" self.state["workflow_status"] = status self._save_state() def get_workflow_status(self) -> str: """获取工作流状态""" return self.state.get("workflow_status", "idle") def reset_state(self): """重置状态""" self.state = { "tasks": {}, "workflow_status": "idle", "created_at": time.time() } self._save_state() # ==================== Worker ==================== class Worker: def __init__(self, worker_id: str, worker_type: str, executor: ThreadPoolExecutor): self.worker_id = worker_id self.worker_type = worker_type self.executor = executor self.status = "idle" self.current_task: Optional[Task] = None self.message_handler: Optional[Callable] = None def set_message_handler(self, handler: Callable): """设置消息处理器""" self.message_handler = handler def receive_message(self, message: Message): """接收消息""" if message.type == MessageType.ASSIGN_TASK: self._handle_assign_task(message) def _handle_assign_task(self, message: Message): """处理任务分配""" task_data = message.payload self.current_task = Task( id=message.task_id, name=task_data["task_name"], description=task_data["description"], input_data=task_data.get("input_data", {}), worker_type=self.worker_type ) self.status = "busy" # 异步执行任务 future = self.executor.submit(self._execute_task, self.current_task) future.add_done_callback(self._on_task_complete) def _execute_task(self, task: Task): """执行任务(由子类实现)""" raise NotImplementedError("子类必须实现_execute_task方法") def _on_task_complete(self, future: Future): """任务完成回调""" try: result = future.result() self.current_task.status = TaskStatus.COMPLETED self.current_task.result = result self.current_task.end_time = time.time() # 发送完成消息 message = Message( type=MessageType.TASK_RESULT, task_id=self.current_task.id, sender=self.worker_id, receiver="planner", payload={ "status": "completed", "result": result }, timestamp=time.time() ) except Exception as e: self.current_task.status = TaskStatus.FAILED self.current_task.error = str(e) self.current_task.end_time = time.time() # 发送失败消息 message = Message( type=MessageType.TASK_FAILURE, task_id=self.current_task.id, sender=self.worker_id, receiver="planner", payload={ "error": str(e), "error_type": type(e).__name__, "retry_count": self.current_task.retry_count, "can_recover": True }, timestamp=time.time() ) # 通知Planner if self.message_handler: self.message_handler(message) self.status = "idle" self.current_task = None # ==================== 具体Worker实现 ==================== class CodeAnalyzerWorker(Worker): def _execute_task(self, task: Task): """执行代码分析任务""" print(f"[{self.worker_id}] 开始分析代码...") # 模拟分析过程 time.sleep(2) # 模拟发现瓶颈 bottlenecks = ["database", "algorithm"] print(f"[{self.worker_id}] 代码分析完成,发现瓶颈:{bottlenecks}") return { "bottlenecks": bottlenecks, "analysis_summary": "发现数据库查询和算法是主要瓶颈" } class DBWorker(Worker): def _execute_task(self, task: Task): """执行数据库任务""" print(f"[{self.worker_id}] 开始执行数据库任务:{task.name}") # 模拟执行过程 time.sleep(3) # 根据任务类型执行不同操作 if "设计" in task.name: result = { "optimization_plan": "添加索引、优化查询语句", "expected_improvement": "30%" } elif "实现" in task.name: result = { "optimizations": ["添加索引", "优化查询"], "performance_improvement": "40%" } else: result = {} print(f"[{self.worker_id}] 数据库任务完成:{task.name}") return result class CacheWorker(Worker): def _execute_task(self, task: Task): """执行缓存任务""" print(f"[{self.worker_id}] 开始执行缓存任务:{task.name}") # 模拟执行过程 time.sleep(2) # 模拟失败场景(仅用于演示) if "实现" in task.name and task.retry_count == 0: raise ConnectionError("Redis连接失败") result = { "cache_strategy": "Redis缓存", "cache_keys": ["user:*", "product:*"], "ttl": 3600 } print(f"[{self.worker_id}] 缓存任务完成:{task.name}") return result class AlgorithmWorker(Worker): def _execute_task(self, task: Task): """执行算法任务""" print(f"[{self.worker_id}] 开始执行算法任务:{task.name}") # 模拟执行过程 time.sleep(4) result = { "optimizations": ["使用快速排序", "优化循环"], "time_complexity": "O(n log n)" } print(f"[{self.worker_id}] 算法任务完成:{task.name}") return result class TestWorker(Worker): def _execute_task(self, task: Task): """执行测试任务""" print(f"[{self.worker_id}] 开始执行测试任务:{task.name}") # 模拟测试过程 time.sleep(5) result = { "performance_test": { "before": "1000ms", "after": "600ms", "improvement": "40%" }, "regression_test": "通过", "all_tests_passed": True } print(f"[{self.worker_id}] 测试任务完成:{task.name}") return result # ==================== Planner ==================== class Planner: def __init__(self, task_tree: TaskTree, state_manager: StateManager, max_workers: int = 4): self.task_tree = task_tree self.state_manager = state_manager self.executor = ThreadPoolExecutor(max_workers=max_workers) self.workers: Dict[str, Worker] = {} self.message_queue = [] self.running = False self._initialize_workers() def _initialize_workers(self): """初始化Workers""" self.workers = { "code_analyzer": CodeAnalyzerWorker("code_analyzer", "code_analyzer", self.executor), "db_worker": DBWorker("db_worker", "db_worker", self.executor), "cache_worker": CacheWorker("cache_worker", "cache_worker", self.executor), "algorithm_worker": AlgorithmWorker("algorithm_worker", "algorithm_worker", self.executor), "test_worker": TestWorker("test_worker", "test_worker", self.executor), } # 设置消息处理器 for worker in self.workers.values(): worker.set_message_handler(self.handle_message) def start(self): """启动Planner""" self.running = True self.state_manager.update_workflow_status("running") print("[Planner] 工作流启动") # 开始调度任务 self._schedule_tasks() def stop(self): """停止Planner""" self.running = False self.state_manager.update_workflow_status("stopped") print("[Planner] 工作流停止") def _schedule_tasks(self): """调度任务""" while self.running and not self.task_tree.is_complete() and not self.task_tree.has_failed(): ready_tasks = self.task_tree.get_ready_tasks() if not ready_tasks: # 没有可以执行的任务,等待一下 time.sleep(0.1) continue # 分配任务 for task in ready_tasks: self._assign_task(task) # 等待任务完成 time.sleep(0.1) # 检查最终状态 if self.task_tree.is_complete(): print("[Planner] 所有任务已完成!") self.state_manager.update_workflow_status("completed") elif self.task_tree.has_failed(): print("[Planner] 工作流失败!") self.state_manager.update_workflow_status("failed") def _assign_task(self, task: Task): """分配任务给合适的Worker""" # 找到合适的Worker suitable_worker = self._find_suitable_worker(task.worker_type) if not suitable_worker: print(f"[Planner] 警告:没有找到类型为 {task.worker_type} 的Worker") return # 更新任务状态 task.status = TaskStatus.RUNNING task.start_time = time.time() self.state_manager.update_task_state(task.id, task) # 发送任务分配消息 message = Message( type=MessageType.ASSIGN_TASK, task_id=task.id, sender="planner", receiver=suitable_worker.worker_id, payload={ "task_name": task.name, "description": task.description, "input_data": task.input_data }, timestamp=time.time() ) suitable_worker.receive_message(message) print(f"[Planner] 已分配任务 '{task.name}' 给 {suitable_worker.worker_id}") def _find_suitable_worker(self, worker_type: str) -> Optional[Worker]: """找到合适的Worker""" for worker in self.workers.values(): if worker.worker_type == worker_type and worker.status == "idle": return worker return None def handle_message(self, message: Message): """处理Worker发来的消息""" if message.type == MessageType.TASK_RESULT: self._handle_task_result(message) elif message.type == MessageType.TASK_FAILURE: self._handle_task_failure(message) def _handle_task_result(self, message: Message): """处理任务结果""" task = self.task_tree.get_task(message.task_id) if task: task.status = TaskStatus.COMPLETED task.end_time = time.time() task.result = message.payload.get("result", {}) self.state_manager.update_task_state(message.task_id, task) print(f"[Planner] 任务 '{task.name}' 完成") # 传递结果给依赖任务 self._propagate_result(task) def _handle_task_failure(self, message: Message): """处理任务失败""" task = self.task_tree.get_task(message.task_id) if task: task.status = TaskStatus.FAILED task.end_time = time.time() task.error = message.payload.get("error", "") self.state_manager.update_task_state(message.task_id, task) print(f"[Planner] 任务 '{task.name}' 失败:{task.error}") # 决定恢复策略 recovery_strategy = self._decide_recovery(task, message.payload) if recovery_strategy == "retry": self._retry_task(task) elif recovery_strategy == "skip": self._skip_task(task) elif recovery_strategy == "rollback": print(f"[Planner] 决定回滚任务 '{task.name}'") # 这里可以实现回滚逻辑 elif recovery_strategy == "terminate": print(f"[Planner] 决定终止工作流") self.stop() def _decide_recovery(self, task: Task, failure_info: dict) -> str: """决定恢复策略""" # 如果任务可以重试,且未超过最大重试次数 if failure_info.get("can_recover", False) and task.retry_count < task.max_retries: return "retry" # 如果是低优先级任务,可以跳过 if task.priority == TaskPriority.LOW: return "skip" # 如果是中优先级任务,且不是关键路径,可以跳过 if task.priority == TaskPriority.MEDIUM: # 检查是否有依赖此任务的高优先级任务 has_high_priority_dependent = any( self.task_tree.get_task(dep_id).priority == TaskPriority.HIGH for t in self.task_tree.tasks.values() for dep_id in t.dependencies if dep_id == task.id ) if not has_high_priority_dependent: return "skip" # 默认终止 return "terminate" def _retry_task(self, task: Task): """重试任务""" task.retry_count += 1 task.status = TaskStatus.PENDING task.error = None task.start_time = None task.end_time = None print(f"[Planner] 重试任务 '{task.name}' (第{task.retry_count}次)") self.state_manager.update_task_state(task.id, task) def _skip_task(self, task: Task): """跳过任务""" task.status = TaskStatus.SKIPPED task.end_time = time.time() print(f"[Planner] 跳过任务 '{task.name}'") self.state_manager.update_task_state(task.id, task) def _propagate_result(self, task: Task): """传递结果给依赖任务""" for t in self.task_tree.tasks.values(): if task.id in t.dependencies: # 将当前任务的结果传递给依赖任务 if task.result: t.input_data[f"task_{task.id}_result"] = task.result self.state_manager.update_task_state(t.id, t) def get_workflow_summary(self) -> dict: """获取工作流摘要""" return { "total_tasks": len(self.task_tree.tasks), "completed_tasks": len(self.task_tree.get_completed_tasks()), "failed_tasks": len(self.task_tree.get_failed_tasks()), "status": self.state_manager.get_workflow_status(), "tasks": { task_id: { "name": task.name, "status": task.status.value, "result": task.result } for task_id, task in self.task_tree.tasks.items() } } # ==================== 主程序 ==================== def create_optimization_workflow() -> TaskTree: """创建性能优化工作流""" tree = TaskTree() # 第一层:分析阶段 tree.add_task(Task( id="1", name="分析性能瓶颈", description="分析代码、数据库、系统资源,找出性能瓶颈", worker_type="code_analyzer", priority=TaskPriority.HIGH )) # 第二层:设计阶段 tree.add_task(Task( id="2.1", name="设计数据库优化方案", description="根据瓶颈分析结果,设计数据库优化方案", dependencies=["1"], worker_type="db_worker", priority=TaskPriority.HIGH )) tree.add_task(Task( id="2.2", name="设计缓存策略", description="设计缓存key、失效策略、更新策略", dependencies=["1"], worker_type="cache_worker", priority=TaskPriority.MEDIUM )) tree.add_task(Task( id="2.3", name="设计算法优化方案", description="根据瓶颈分析结果,设计算法优化方案", dependencies=["1"], worker_type="algorithm_worker", priority=TaskPriority.HIGH )) # 第三层:实现阶段 tree.add_task(Task( id="3.1", name="实现数据库优化", description="实现数据库查询优化、索引优化", dependencies=["2.1"], worker_type="db_worker", priority=TaskPriority.HIGH )) tree.add_task(Task( id="3.2", name="实现缓存", description="实现Redis缓存", dependencies=["2.2"], worker_type="cache_worker", priority=TaskPriority.MEDIUM )) tree.add_task(Task( id="3.3", name="实现算法优化", description="实现算法优化", dependencies=["2.3"], worker_type="algorithm_worker", priority=TaskPriority.HIGH )) # 第四层:验证阶段 tree.add_task(Task( id="4", name="验证效果", description="进行性能测试和回归测试", dependencies=["3.1", "3.2", "3.3"], worker_type="test_worker", priority=TaskPriority.HIGH )) return tree def main(): """主函数""" print("=" * 60) print("Agent设计模式实战:Planning模式") print("=" * 60) print() # 创建任务树 task_tree = create_optimization_workflow() print(f"已创建工作流,共 {len(task_tree.tasks)} 个任务") print() # 创建状态管理器 state_manager = StateManager("workflow_state.json") state_manager.reset_state() # 创建Planner planner = Planner(task_tree, state_manager) # 启动工作流 planner.start() # 等待工作流完成 while planner.running: time.sleep(0.5) # 输出摘要 print() print("=" * 60) print("工作流摘要") print("=" * 60) summary = planner.get_workflow_summary() print(json.dumps(summary, indent=2, ensure_ascii=False)) # 关闭线程池 planner.executor.shutdown() if __name__ == "__main__": main()
▪ 代码解析
这个实现包含了Planning模式的所有核心要素:
关键设计点:
get_ready_tasks()方法自动检查依赖是否满足▪ 运行示例
运行这个程序,你会看到类似这样的输出:
============================================================ Agent设计模式实战:Planning模式 ============================================================ 已创建工作流,共 8 个任务 [Planner] 工作流启动 [Planner] 已分配任务 '分析性能瓶颈' 给 code_analyzer [code_analyzer] 开始分析代码... [code_analyzer] 代码分析完成,发现瓶颈:['database', 'algorithm'] [Planner] 任务 '分析性能瓶颈' 完成 [Planner] 已分配任务 '设计数据库优化方案' 给 db_worker [Planner] 已分配任务 '设计缓存策略' 给 cache_worker [Planner] 已分配任务 '设计算法优化方案' 给 algorithm_worker [db_worker] 开始执行数据库任务:设计数据库优化方案 [cache_worker] 开始执行缓存任务:设计缓存策略 [algorithm_worker] 开始执行算法任务:设计算法优化方案 [db_worker] 数据库任务完成:设计数据库优化方案 [cache_worker] 缓存任务完成:设计缓存策略 [algorithm_worker] 算法任务完成:设计算法优化方案 [Planner] 任务 '设计数据库优化方案' 完成 [Planner] 任务 '设计缓存策略' 完成 [Planner] 任务 '设计算法优化方案' 完成 [Planner] 已分配任务 '实现数据库优化' 给 db_worker [cache_worker] 开始执行缓存任务:实现缓存 [cache_worker] 缓存任务失败:Redis连接失败 [Planner] 任务 '实现缓存' 失败:Redis连接失败 [Planner] 重试任务 '实现缓存' (第1次) [Planner] 已分配任务 '实现缓存' 给 cache_worker [cache_worker] 开始执行缓存任务:实现缓存 [cache_worker] 缓存任务完成:实现缓存 [Planner] 任务 '实现缓存' 完成 [db_worker] 数据库任务完成:实现数据库优化 [algorithm_worker] 算法任务完成:实现算法优化 [Planner] 任务 '实现数据库优化' 完成 [Planner] 任务 '实现算法优化' 完成 [Planner] 已分配任务 '验证效果' 给 test_worker [test_worker] 开始执行测试任务:验证效果 [test_worker] 测试任务完成:验证效果 [Planner] 任务 '验证效果' 完成 [Planner] 所有任务已完成! ============================================================ 工作流摘要 ============================================================ { "total_tasks": 8, "completed_tasks": 8, "failed_tasks": 0, "status": "completed", "tasks": { "1": { "name": "分析性能瓶颈", "status": "completed", "result": { "bottlenecks": ["database", "algorithm"], "analysis_summary": "发现数据库查询和算法是主要瓶颈" } }, ... } }
常见坑和解决方案
▪ 坑1:任务拆解太细
现象: 拆出了几十个小任务,管理成本比任务本身还高。
例子:
任务:实现用户登录功能 错误的拆解: 1. 创建用户表 2. 创建用户模型 3. 创建用户控制器 4. 实现登录路由 5. 实现注册路由 6. 实现密码加密 7. 实现Token生成 8. 实现Token验证 9. 实现会话管理 10. 实现登录验证中间件 ...
原因: 拆解时没有考虑任务的原子性和内聚性。
解决方案:
正确的拆解: 1. 实现用户注册功能(包含数据库、模型、路由、验证) 2. 实现用户登录功能(包含密码加密、Token生成、会话管理) 3. 实现登录验证中间件
▪ 坑2:依赖关系混乱
现象: 任务之间形成循环依赖,导致无法执行。
例子:
任务1:优化数据库(依赖任务2的缓存设计) 任务2:设计缓存策略(依赖任务1的查询逻辑) 任务3:实现算法(依赖任务2的缓存策略)
原因: 没有进行依赖分析,直接根据经验拆解任务。
解决方案:
def validate_dependencies(task_tree: TaskTree) -> bool: """验证任务依赖是否合法(无循环依赖)""" # 实现拓扑排序 visited = set() temp_visited = set() def visit(task_id: str) -> bool: if task_id in temp_visited: return False # 循环依赖 if task_id in visited: return True temp_visited.add(task_id) task = task_tree.get_task(task_id) if task: for dep_id in task.dependencies: if not visit(dep_id): return False temp_visited.remove(task_id) visited.add(task_id) return True for task_id in task_tree.tasks: if not visit(task_id): return False return True
▪ 坑3:失败恢复策略不当
现象: 一个任务失败后,重试导致连锁反应,整个工作流崩溃。
例子:
任务3.2(实现缓存)失败 → 重试 → 还是失败 → 继续重试 → ... 无限循环,整个工作流卡住
原因: 没有设置最大重试次数,或者重试策略太激进。
解决方案:
def calculate_backoff(retry_count: int, base_delay: float = 1.0) -> float: """计算指数退避时间""" return base_delay * (2 ** retry_count) # 使用示例 if task.status == TaskStatus.FAILED: if task.retry_count < task.max_retries: # 指数退避 backoff = calculate_backoff(task.retry_count) time.sleep(backoff) self._retry_task(task) else: # 超过最大重试次数,决定是跳过还是终止 if task.priority == TaskPriority.LOW: self._skip_task(task) else: self.stop()
▪ 坑4:状态不一致
现象: 任务树状态和持久化状态不一致,导致恢复时出错。
例子:
系统崩溃前: - 任务1:COMPLETED - 任务2:RUNNING - 任务3:PENDING 系统重启后: - 任务1:COMPLETED - 任务2:PENDING(应该标记为FAILED) - 任务3:PENDING 任务2被重新执行,但实际上它之前已经部分完成了,导致数据不一致。
原因: 状态更新不是原子操作,或者在任务执行过程中崩溃。
解决方案:
def reset_running_tasks(task_tree: TaskTree): """重置所有运行中的任务""" for task in task_tree.tasks.values(): if task.status == TaskStatus.RUNNING: task.status = TaskStatus.PENDING task.start_time = None task.progress = 0.0
最后说句实话
Planning模式看起来复杂,但一旦你理解了它的核心思想,就会觉得非常自然。
核心思想就是两个词:分层、恢复。
分层:把复杂任务拆成多个层次,每层只关注自己的职责。
恢复:任何时候都能从失败中恢复,而不是从头开始。
但用好Planning模式,需要你真正理解业务。你需要知道:
这些都不是技术问题,是业务理解问题。
好的Planner不是代码写得好,而是对业务理解得深。
下一篇,我们讲Multi-Agent模式——多个Agent如何协作,解决更复杂的问题。
💡 一句话带走:Planning的核心不是"切任务",而是分层决策和失败恢复。理解业务的复杂度,才能设计出真正可用的Planner。
你用过Planning模式吗?遇到过哪些坑?评论区聊聊。