参考:算法图解
# 在未处理的节点中找到开销最小的节点 def find_lowest_cost_node(costs, processed): lowest = float("inf") lowest_cost_node = None for node in costs: cost = costs[node] if cost < lowest and node not in processed: lowest = cost lowest_cost_node = node return lowest_cost_node def func(graph, costs, parents): processed = [] # 记录处理过的节点 node = find_lowest_cost_node(costs, processed) # 在未处理的节点中找到开销最小的节点 while node: cost = costs[node] neighbors = graph[node] for n in neighbors.keys(): # 更新到达邻居节点的花费 new_cost = cost + neighbors[n] if costs[n] > new_cost: costs[n] = new_cost parents[n] = node # 设置父节点 processed.append(node) node = find_lowest_cost_node(costs, processed) return costs["final"] if __name__ == '__main__': # graph dict # 记录每个节点的到邻居的花费 graph = {"start": {}, "a": {}, "b": {}, "final": {}} graph["start"]["a"] = 6 graph["start"]["b"] = 2 graph["a"]["final"] = 1 graph["b"]["a"] = 3 graph["b"]["final"] = 5 # cost dict # 记录到达每个节点的最小花费 infinity = float("inf") costs = {"a": 6, "b": 2, "final": infinity} # parents dict # 记录每个节点的前一个节点(花费最少) parents = {"a": "start", "b": "start", "final": None} print(func(graph, costs, parents))