def prim(graph, n):
    """Return the order in which Prim's algorithm adds vertices to the tree.

    Starting from vertex 0, repeatedly pick the unvisited vertex whose
    cheapest known connection to the already-visited set is smallest
    (ties go to the lowest index), then relax the distances of the
    remaining vertices through it.

    Args:
        graph: n x n matrix of edge weights. Only ``graph[v][u]`` (row =
            candidate vertex, column = newly added vertex) is read, so the
            matrix is presumably symmetric -- TODO confirm with callers.
            A weight of ``float("inf")`` never improves a distance, so it
            behaves like a missing edge.
        n: number of vertices to process (rows/columns 0..n-1 are used).

    Returns:
        A list containing each vertex 0..n-1 exactly once, in visit order.
        If the graph is disconnected, the lowest-indexed unvisited vertex
        is taken whenever no remaining vertex is reachable, i.e. a new tree
        of the spanning forest is started.
    """
    inf = float("inf")
    visited = [False] * n
    dist = [inf] * n  # cheapest known edge from the visited set to each vertex
    result = []

    for _ in range(n):
        # Select the unvisited vertex with the smallest finite distance.
        idx = -1
        best = inf
        for v in range(n):
            if not visited[v] and dist[v] < best:
                best = dist[v]
                idx = v

        if idx == -1:
            # Nothing reachable from the current tree (always the case on the
            # first pass): start a new tree at the first unvisited vertex
            # instead of reusing a stale index.
            idx = visited.index(False)

        result.append(idx)
        visited[idx] = True

        # Relax distances of the remaining vertices through the new vertex.
        for v in range(n):
            if not visited[v] and graph[v][idx] < dist[v]:
                dist[v] = graph[v][idx]

    return result
if __name__ == "__main__":
    # Small 4-vertex demo matrix; prints the order in which vertices are visited.
    graph = [
        [1, 5, 3, 6],
        [5, 8, 9, 6],
        [3, 9, 8, 6],
        [6, 6, 6, 3],
    ]
    vertex_count = len(graph)
    print(prim(graph, vertex_count))