以下是添加了类型标注的代码:
fromtypingimportList,Tuple,Optional,Anyimportnumpyasnpimportmatplotlib.pyplotaspltfromscipy.spatialimportKDTreeimportmathimportrandomclassNode:"""RRT*节点类"""def__init__(self,x:float,y:float)->None:self.x:float=x self.y:float=y self.parent:Optional['Node']=Noneself.cost:float=0.0# 从起点到当前节点的代价def__repr__(self)->str:returnf"Node({self.x:.2f},{self.y:.2f})"classRRTStar:"""RRT*路径规划算法实现,使用KD-Tree进行最近邻搜索"""def__init__(self,start:Tuple[float,float],goal:Tuple[float,float],obstacles:List[Tuple[float,float,float]],bounds:Tuple[float,float,float,float],max_iter:int=1000,step_size:float=0.5,neighbor_radius:float=2.0)->None:""" 初始化RRT*算法 参数: start: 起点坐标 (x, y) goal: 终点坐标 (x, y) obstacles: 障碍物列表,每个障碍物为(x, y, radius) bounds: 地图边界 (x_min, x_max, y_min, y_max) max_iter: 最大迭代次数 step_size: 步长 neighbor_radius: 邻居搜索半径 """self.start:Node=Node(start[0],start[1])self.goal:Node=Node(goal[0],goal[1])self.obstacles:List[Tuple[float,float,float]]=obstacles self.bounds:Tuple[float,float,float,float]=bounds self.max_iter:int=max_iter self.step_size:float=step_size self.neighbor_radius:float=neighbor_radius# 节点列表self.nodes:List[Node]=[self.start]# 最终路径self.final_path:Optional[List[Tuple[float,float]]]=Nonedefdistance(self,node1:Node,node2:Node)->float:"""计算两个节点之间的欧几里得距离"""returnmath.sqrt((node1.x-node2.x)**2+(node1.y-node2.y)**2)defnearest(self,point:Node)->Optional[Node]:"""使用KD-Tree查找最近节点(替代线性搜索)"""# 构建KD-Treeiflen(self.nodes)==0:returnNone# 提取所有节点的坐标points:np.ndarray=np.array([[node.x,node.y]fornodeinself.nodes])kdtree:KDTree=KDTree(points)# 查询最近邻dist:floatidx:intdist,idx=kdtree.query([point.x,point.y])returnself.nodes[idx]defsteer(self,from_node:Node,to_node:Node)->Node:"""生成新节点(从from_node向to_node方向生长step_size距离)"""d:float=self.distance(from_node,to_node)# 如果距离小于步长,直接返回目标节点ifd<=self.step_size:new_node:Node=Node(to_node.x,to_node.y)else:# 计算方向向量theta:float=math.atan2(to_node.y-from_node.y,to_node.x-from_node.x)new_x:float=from_node.x+self.step_size*math.cos(theta)new_y:float=from_node.y+self.step_size*math.sin(theta)new_node=Node(new_x,new_y)returnnew_nodedefis_collision_free(self,node1:Node,node2:Node)->bool:"""检查两点之间路径是否与障碍物碰撞"""# 采样点检查碰撞num_check:int=10foriinrange(num_check+1):t:float=i/num_check x:float=node1.x+t*(node2.x-node1.x)y:float=node1.y+t*(node2.y-node1.y)for(ox,oy,radius)inself.obstacles:dist:float=math.sqrt((x-ox)**2+(y-oy)**2)ifdist<=radius:returnFalsereturnTruedeffind_near_nodes(self,node:Node)->List[Node]:"""使用KD-Tree在半径内查找邻居节点"""iflen(self.nodes)<2:return[]# 构建KD-Treepoints:np.ndarray=np.array([[n.x,n.y]forninself.nodes])kdtree:KDTree=KDTree(points)# 半径查询indices:List[int]=kdtree.query_ball_point([node.x,node.y],self.neighbor_radius)# 排除节点自身(如果是已存在的节点)near_nodes:List[Node]=[self.nodes[i]foriinindicesifself.nodes[i]!=node]returnnear_nodesdefchoose_parent(self,new_node:Node,near_nodes:List[Node])->bool:"""为new_node选择最优父节点"""min_cost:float=float('inf')best_parent:Optional[Node]=Nonefornear_nodeinnear_nodes:# 检查是否无碰撞ifself.is_collision_free(near_node,new_node):# 计算通过near_node到达new_node的代价cost:float=near_node.cost+self.distance(near_node,new_node)ifcost<min_cost:min_cost=cost best_parent=near_nodeifbest_parentisnotNone:new_node.parent=best_parent new_node.cost=min_costreturnTruereturnFalsedefrewire(self,new_node:Node,near_nodes:List[Node])->int:"""重连接函数 - RRT*算法的核心优化步骤"""rewire_count:int=0fornear_nodeinnear_nodes:# 检查new_node是否可以成为near_node的更好父节点ifnear_node==new_node.parent:continue# 检查是否无碰撞ifself.is_collision_free(new_node,near_node):# 计算通过new_node到达near_node的新代价new_cost:float=new_node.cost+self.distance(new_node,near_node)# 如果新代价更小,则重连接ifnew_cost<near_node.cost:near_node.parent=new_node near_node.cost=new_cost rewire_count+=1# 递归更新子节点的代价self.update_children_cost(near_node)returnrewire_countdefupdate_children_cost(self,parent_node:Node)->None:"""递归更新子节点的代价"""# 查找所有子节点(注意:这里简化处理,实际需要维护子节点列表)fornodeinself.nodes:ifnode.parent==parent_node:node.cost=parent_node.cost+self.distance(parent_node,node)self.update_children_cost(node)defrandom_node(self)->Node:"""生成随机节点(90%偏向目标点)"""ifrandom.random()>0.1:returnself.goal x_min,x_max,y_min,y_max=self.boundsreturnNode(random.uniform(x_min,x_max),random.uniform(y_min,y_max))defcheck_goal(self,node:Node)->bool:"""检查是否到达目标点附近"""returnself.distance(node,self.goal)<=self.step_sizedeffind_path(self)->Optional[List[Tuple[float,float]]]:"""执行RRT*路径规划"""foriterationinrange(self.max_iter):# 1. 生成随机节点random_node:Node=self.random_node()# 2. 查找最近节点nearest_node:Optional[Node]=self.nearest(random_node)ifnearest_nodeisNone:continue# 3. 生成新节点new_node:Node=self.steer(nearest_node,random_node)# 4. 检查碰撞ifnotself.is_collision_free(nearest_node,new_node):continue# 5. 查找邻居节点near_nodes:List[Node]=self.find_near_nodes(new_node)# 6. 选择最优父节点ifnotself.choose_parent(new_node,near_nodes):continue# 7. 添加到节点列表self.nodes.append(new_node)# 8. 执行重连接(REWIRE)self.rewire(new_node,near_nodes)# 9. 检查是否到达目标ifself.check_goal(new_node):# 尝试将目标节点连接到路径ifself.is_collision_free(new_node,self.goal):self.goal.parent=new_node self.goal.cost=new_node.cost+self.distance(new_node,self.goal)self.nodes.append(self.goal)print(f"找到路径!迭代次数:{iteration}, 节点数:{len(self.nodes)}")break# 提取最终路径returnself.extract_path()defextract_path(self)->Optional[List[Tuple[float,float]]]:"""从目标节点回溯提取路径"""ifself.goal.parentisNone:returnNonepath:List[Tuple[float,float]]=[]node:Optional[Node]=self.goalwhilenodeisnotNone:path.append((node.x,node.y))node=node.parent path.reverse()self.final_path=pathreturnpathdefget_path_cost(self)->float:"""计算路径代价"""ifself.final_pathisNone:returnfloat('inf')cost:float=0.0foriinrange(len(self.final_path)-1):x1,y1=self.final_path[i]x2,y2=self.final_path[i+1]cost+=math.sqrt((x2-x1)**2+(y2-y1)**2)returncostdefvisualize(self,show:bool=True)->Any:# 返回类型是matplotlib.figure.Figure"""可视化结果"""plt.figure(figsize=(10,10))# 绘制障碍物for(x,y,radius)inself.obstacles:circle=plt.Circle((x,y),radius,color='gray',alpha=0.5)plt.gca().add_patch(circle)# 绘制所有节点和连接fornodeinself.nodes:ifnode.parentisnotNone:plt.plot([node.x,node.parent.x],[node.y,node.parent.y],'lightgray',linewidth=0.5,alpha=0.5)plt.plot(node.x,node.y,'o',markersize=3,color='blue',alpha=0.3)# 绘制起点和终点plt.plot(self.start.x,self.start.y,'ro',markersize=10,label='起点')plt.plot(self.goal.x,self.goal.y,'go',markersize=10,label='终点')# 绘制最终路径ifself.final_pathisnotNone:path_x:List[float]=[p[0]forpinself.final_path]path_y:List[float]=[p[1]forpinself.final_path]plt.plot(path_x,path_y,'r-',linewidth=2,label='最终路径')print(f"路径长度:{self.get_path_cost():.2f}")# 设置图形属性x_min,x_max,y_min,y_max=self.bounds plt.xlim(x_min,x_max)plt.ylim(y_min,y_max)plt.grid(True,alpha=0.3)plt.legend()plt.title(f'RRT* 路径规划 (节点数:{len(self.nodes)})')plt.xlabel('X')plt.ylabel('Y')ifshow:plt.show()returnplt.gcf()# 示例使用defmain()->None:# 设置参数start:Tuple[float,float]=(0,0)goal:Tuple[float,float]=(10,10)bounds:Tuple[float,float,float,float]=(-2,12,-2,12)# (x_min, x_max, y_min, y_max)# 创建障碍物obstacles:List[Tuple[float,float,float]]=[(3,3,1.5),(6,6,1.2),(8,2,1.0),(4,8,1.3),(7,8,1.0),(2,5,0.8),(5,2,0.7),(9,6,1.1)]# 创建RRT*规划器rrt_star:RRTStar=RRTStar(start=start,goal=goal,obstacles=obstacles,bounds=bounds,max_iter=2000,step_size=0.5,neighbor_radius=2.0)# 执行路径规划print("开始RRT*路径规划...")path:Optional[List[Tuple[float,float]]]=rrt_star.find_path()ifpathisNone:print("未找到路径!")else:print(f"找到路径,包含{len(path)}个点")print(f"路径总代价:{rrt_star.get_path_cost():.2f}")# 可视化rrt_star.visualize()if__name__=="__main__":main()主要添加的类型标注包括:
- 函数参数和返回类型标注:为所有方法添加了参数类型和返回类型
- 类属性类型标注:为类的所有属性添加了类型标注
- 局部变量类型标注:为重要的局部变量添加了类型标注
- 导入类型模块:添加了 from typing import … 导入
- 特殊类型处理:
· 使用 Optional[] 表示可能为 None 的返回值
· 使用 Tuple[] 表示元组类型
· 使用 List[] 表示列表类型
· 对于 visualize 方法的返回类型,使用 Any 因为 plt.gcf() 返回的是复杂的 matplotlib 图形对象
这些类型标注使代码更加清晰,有助于静态类型检查工具(如 mypy)进行错误检查,也提高了代码的可读性和可维护性。