ShortestPaths的源码如下:
package org.apache.spark.graphx.lib

import scala.reflect.ClassTag

import org.apache.spark.graphx._

/**
 * Hop-count shortest paths towards a fixed set of landmark vertices.
 *
 * In the result graph every vertex carries a map whose keys are the landmarks it can reach
 * (following edges from source to destination) and whose values are the number of hops needed.
 */
object ShortestPaths {

  /** Landmark vertex id -> number of hops needed to reach that landmark. */
  type SPMap = Map[VertexId, Int]

  /** Builds an [[SPMap]] from explicit `(landmark, hops)` entries. */
  private def makeMap(entries: (VertexId, Int)*): SPMap = entries.toMap

  /** Adds one hop to every distance stored in `spmap`. */
  private def incrementMap(spmap: SPMap): SPMap = spmap.transform((_, hops) => hops + 1)

  /**
   * Per-landmark minimum of two maps. A landmark missing from one side is treated as
   * `Int.MaxValue` (unreachable) on that side.
   */
  private def addMaps(spmap1: SPMap, spmap2: SPMap): SPMap = {
    val landmarkIds = spmap1.keySet union spmap2.keySet
    landmarkIds.iterator.map { id =>
      val viaFirst = spmap1.getOrElse(id, Int.MaxValue)
      val viaSecond = spmap2.getOrElse(id, Int.MaxValue)
      id -> (viaFirst min viaSecond)
    }.toMap
  }

  /**
   * Computes, for every vertex of `graph`, its hop distance to each landmark it can reach.
   *
   * @tparam ED edge attribute type; carried through unchanged and never inspected
   *
   * @param graph     the input graph
   * @param landmarks ids of the landmark vertices to measure distances to
   *
   * @return `graph` with each vertex attribute replaced by its [[SPMap]] of reachable landmarks
   */
  def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
    // A landmark is 0 hops away from itself; every other vertex starts knowing nothing.
    val spGraph = graph.mapVertices { (vid, _) =>
      if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
    }

    val initialMessage = makeMap()

    // Fold an incoming message into the vertex's current distances.
    def vertexProgram(id: VertexId, attr: SPMap, msg: SPMap): SPMap = addMaps(attr, msg)

    // Offer the destination's distances (one hop longer) to the source, but only when doing so
    // would change what the source already knows.
    def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = {
      val offered = incrementMap(edge.dstAttr)
      val merged = addMaps(offered, edge.srcAttr)
      if (merged == edge.srcAttr) Iterator.empty else Iterator.single(edge.srcId -> offered)
    }

    Pregel(spGraph, initialMessage)(vertexProgram, sendMessage, addMaps)
  }
}
关于单源最短路径,我们可以调用 ShortestPaths.run(graph, landmarks) 得到 graph 中各顶点到 landmarks 的“距离”,但是这个“距离”只是“跳数”(经过的边数)。换句话说,只有在 graph 中每条边的权重都为 1 的情况下,才能保证结果的正确性。而现实中的图往往都不满足这个条件。那么问题来了,我们该如何做呢?学过图论的朋友都知道,对于边权非负的图,Dijkstra 算法可以解决这个问题。遗憾的是,GraphX 目前(Spark 2.0.2)并未提供这样的 API,所以基于 GraphX 实现 Dijkstra 算法变得很有必要。
/**
 * Single-source shortest paths computed with Dijkstra's algorithm.
 *
 * Edges are followed in their stored direction (src -> dst); to treat the graph as undirected,
 * add the reverse of every edge first. Edge weights must be non-negative, otherwise Dijkstra's
 * greedy choice is not valid.
 *
 * Every iteration settles one vertex (the unvisited vertex closest to `origin`) and relaxes its
 * out-edges, so Spark jobs are launched once per reachable vertex. The loop stops early once all
 * remaining unvisited vertices are unreachable, since relaxing from them cannot change anything.
 *
 * @param g      graph whose edge attribute is the edge weight
 * @param origin id of the source vertex
 * @return `g` with each vertex attribute replaced by `(originalAttr, distance, predecessorId)`.
 *         The origin gets `(attr, 0.0, -1L)`; vertices not reachable from `origin` (all of them
 *         if `origin` is not in the graph) get `(attr, Double.MaxValue, -1L)`.
 */
def dijkstra[VD: ClassTag](g: Graph[VD, Double], origin: VertexId): Graph[(VD, Double, Long), Double] = {
  // Working state per vertex: (visited, best known distance from origin, predecessor id or -1L).
  var g2 = g.mapVertices { (vid, _) =>
    (false, if (vid == origin) 0.0 else Double.MaxValue, -1L)
  }.cache()

  val vertexCount = g.vertices.count()
  var settled = 0L
  var reachableLeft = true
  while (reachableLeft && settled < vertexCount) {
    // Closest unvisited vertex. The (-1L, +Infinity) zero value keeps the fold well-defined
    // even if no unvisited vertex remains.
    val (currentVertexId, currentDistance) = g2.vertices
      .filter { case (_, (visited, _, _)) => !visited }
      .map { case (vid, (_, dist, _)) => (vid, dist) }
      .fold((-1L, Double.PositiveInfinity)) { (a, b) => if (a._2 < b._2) a else b }

    if (currentDistance >= Double.MaxValue) {
      // Only unreachable vertices are left: relaxing their edges cannot change any result.
      reachableLeft = false
    } else {
      // Relax the out-edges of the settled vertex; only unvisited vertices receive candidates.
      val candidates = g2.aggregateMessages[(Double, Long)](
        triplet =>
          if (triplet.srcId == currentVertexId && !triplet.dstAttr._1) {
            triplet.sendToDst((triplet.srcAttr._2 + triplet.attr, triplet.srcId))
          },
        (x, y) => if (x._1 < y._1) x else y,
        TripletFields.All // reads srcAttr (distance) and dstAttr (visited flag)
      )

      val next = g2.outerJoinVertices(candidates) { (vid, state, candidate) =>
        val (visited, dist, parent) = state
        val nowVisited = visited || vid == currentVertexId
        candidate match {
          case Some((newDist, via)) if newDist < dist => (nowVisited, newDist, via)
          case _ => (nowVisited, dist, parent)
        }
      }.cache()

      // Materialize the new state before dropping the previous one, so later iterations do not
      // recompute the whole lineage from scratch.
      next.vertices.count()
      next.edges.count()
      g2.unpersist(blocking = false)
      g2 = next
      settled += 1
    }
  }

  g.outerJoinVertices(g2.vertices) { (_, attr, state) =>
    val (_, dist, parent) = state.getOrElse((false, Double.MaxValue, -1L))
    (attr, dist, parent)
  }
}
知道Dijkstra算法的人大多也知道Prim算法(下面代码中的函数名为 prime)。Prim算法用于求最小生成树,实现与Dijkstra几乎相同,区别仅在于:顶点的代价不再是到源点的累计距离,而是把它连入生成树的那条边的权重。
/**
 * Minimum spanning tree computed with Prim's algorithm (the function keeps the name `prime`
 * for compatibility with existing callers).
 *
 * Edges are followed in their stored direction (src -> dst); for an undirected MST the graph
 * must contain both directions of every edge.
 *
 * Each of the `N` iterations (N = number of vertices) adds the unvisited vertex with the
 * smallest connection cost to the tree and offers its out-edge weights to unvisited neighbours.
 * When some vertices are not connected to `origin`, the loop continues from an arbitrary
 * remaining vertex, which becomes an extra root with cost `Double.MaxValue` and predecessor
 * `-1L`. The result is then a spanning forest.
 *
 * @param g      graph whose edge attribute is the edge weight
 * @param origin id of the root vertex
 * @return `g` with each vertex attribute replaced by `(originalAttr, costOfTreeEdge, parentId)`;
 *         the root gets `(attr, 0.0, -1L)`.
 */
def prime[VD: ClassTag](g: Graph[VD, Double], origin: VertexId): Graph[(VD, Double, Long), Double] = {
  // Working state per vertex: (visited, cheapest known edge into the tree, parent id or -1L).
  var g2 = g.mapVertices { (vid, _) =>
    (false, if (vid == origin) 0.0 else Double.MaxValue, -1L)
  }.cache()

  val vertexCount = g.vertices.count()
  for (_ <- 1L to vertexCount) {
    // Unvisited vertex with the smallest connection cost. Exactly one vertex is marked visited
    // per iteration, so at least one unvisited vertex remains here.
    val currentVertexId = g2.vertices
      .filter { case (_, (visited, _, _)) => !visited }
      .reduce((a, b) => if (a._2._2 < b._2._2) a else b)
      ._1

    // Offer the weight of each out-edge of the newly added vertex to its unvisited neighbours.
    val candidates = g2.aggregateMessages[(Double, Long)](
      triplet =>
        if (triplet.srcId == currentVertexId && !triplet.dstAttr._1) {
          triplet.sendToDst((triplet.attr, triplet.srcId))
        },
      (x, y) => if (x._1 < y._1) x else y,
      TripletFields.Dst // only dstAttr (visited flag) is read; ids and edge attr are always available
    )

    val next = g2.outerJoinVertices(candidates) { (vid, state, candidate) =>
      val (visited, cost, parent) = state
      val nowVisited = visited || vid == currentVertexId
      candidate match {
        case Some((newCost, via)) if newCost < cost => (nowVisited, newCost, via)
        case _ => (nowVisited, cost, parent)
      }
    }.cache()

    // Materialize the new state before dropping the previous one, so later iterations do not
    // recompute the whole lineage from scratch.
    next.vertices.count()
    next.edges.count()
    g2.unpersist(blocking = false)
    g2 = next
  }

  g.outerJoinVertices(g2.vertices) { (_, attr, state) =>
    val (_, cost, parent) = state.getOrElse((false, Double.MaxValue, -1L))
    (attr, cost, parent)
  }
}
/**
 * All-pairs (multi-source) shortest paths.
 *
 * Despite its name this is not the classic triple-loop Floyd–Warshall recurrence. It is a
 * synchronous distance-vector (Bellman–Ford style) relaxation: every vertex keeps a map
 * `target -> best known distance`. In each round it receives from every out-neighbour the
 * entries of that neighbour's map, shifted by the connecting edge weight, that would shorten
 * one of its own distances. Edges are followed in their stored direction (src -> dst).
 *
 * The loop stops as soon as a round produces no improvement, and in any case after `N` rounds
 * (N = number of vertices): without negative cycles a shortest path has at most N - 1 edges.
 * Unlike a "every map has N entries" test, this also terminates on graphs that are not
 * strongly connected and only stops after the distances have actually converged.
 *
 * @param g graph whose edge attribute is the edge weight
 * @return graph whose vertex attribute maps every target reachable from the vertex to the
 *         shortest distance; each vertex maps itself to `0.0`. An empty graph is returned as is.
 */
def floydWarshall[VD: ClassTag](g: Graph[VD, Double]): Graph[Map[VertexId, Double], Double] = {
  // Element-wise minimum of two distance maps. Kept as a function value so the Spark closures
  // below can use it directly.
  val mergeMaps: (Map[VertexId, Double], Map[VertexId, Double]) => Map[VertexId, Double] =
    (a, b) => (a.keySet ++ b.keySet).map { k =>
      k -> math.min(a.getOrElse(k, Double.MaxValue), b.getOrElse(k, Double.MaxValue))
    }.toMap

  val vertexCount = g.vertices.count()
  var g2 = g.mapVertices((vid, _) => Map(vid -> 0.0)).cache()
  var round = 0L
  var improved = true
  while (improved && round < vertexCount) {
    // Each source only receives the entries (neighbour distance + edge weight) that beat what it
    // already knows, so an empty result means the distances have converged.
    val updates = g2.aggregateMessages[Map[VertexId, Double]](
      triplet => {
        val known = triplet.srcAttr
        val shorter = triplet.dstAttr.iterator
          .map { case (target, d) => target -> (d + triplet.attr) }
          .filter { case (target, d) => d < known.getOrElse(target, Double.MaxValue) }
          .toMap
        if (shorter.nonEmpty) triplet.sendToSrc(shorter)
      },
      mergeMaps,
      TripletFields.All // srcAttr is read, so both endpoint attributes must be shipped
    ).cache()

    improved = updates.count() > 0
    if (improved) {
      // Vertices that received nothing keep their current map.
      val next = g2.outerJoinVertices(updates) { (_, current, update) =>
        update.fold(current)(mergeMaps(current, _))
      }.cache()
      // Materialize before dropping the previous state to keep the lineage short.
      next.vertices.count()
      next.edges.count()
      g2.unpersist(blocking = false)
      g2 = next
    }
    updates.unpersist(blocking = false)
    round += 1
  }
  g2
}
纸上得来终觉浅,绝知此事要躬行。下面开始实战、实战、实战,重要的事情说三遍!!!
// Example: the classic 7-vertex weighted graph, queried from vertex 3 ("C").
val myVertices = sc.makeRDD(Array(
  (1L, "A"), (2L, "B"), (3L, "C"), (4L, "D"), (5L, "E"), (6L, "F"), (7L, "G")))

val initialEdges = sc.makeRDD(Array(
  Edge(1L, 2L, 7.0), Edge(1L, 4L, 5.0), Edge(2L, 3L, 8.0), Edge(2L, 4L, 9.0),
  Edge(2L, 5L, 7.0), Edge(3L, 5L, 5.0), Edge(4L, 5L, 15.0), Edge(4L, 6L, 6.0),
  Edge(5L, 6L, 8.0), Edge(5L, 7L, 9.0), Edge(6L, 7L, 11.0)))

// Drop self-loops, add the reverse of every edge (directed -> undirected), remove duplicates.
val myEdges = initialEdges
  .filter(e => e.srcId != e.dstId)
  .flatMap(e => Array(e, Edge(e.dstId, e.srcId, e.attr)))
  .distinct()

val myGraph = Graph(myVertices, myEdges).cache()

println(ShortestPaths.run(myGraph, Seq(3L)).vertices.collect().mkString(","))
println(dijkstra(myGraph, 3L).vertices.collect().mkString(" | "))
println(prime(myGraph, 3L).vertices.collect().mkString(" | "))
// collect() first so the rows are printed by the driver rather than on the executors.
floydWarshall(myGraph).vertices.collect().foreach(println)
输出依次如下:
ShortestPaths: (1,Map(3 -> 2)) | (2,Map(3 -> 1)) | (3,Map(3 -> 0)) | (4,Map(3 -> 2)) | (5,Map(3 -> 1)) | (6,Map(3 -> 2)) | (7,Map(3 -> 2))
Dijkstra: (1,(A,15.0,2)) | (2,(B,8.0,3)) | (3,(C,0.0,-1)) | (4,(D,17.0,2)) | (5,(E,5.0,3)) | (6,(F,13.0,5)) | (7,(G,14.0,5))
Prime: (1,(A,7.0,2)) | (2,(B,7.0,5)) | (3,(C,0.0,-1)) | (4,(D,5.0,1)) | (5,(E,5.0,3)) | (6,(F,6.0,4)) | (7,(G,9.0,5))
FloydWarshall:
(4,Map(5 -> 14.0, 1 -> 5.0, 6 -> 6.0, 2 -> 9.0, 7 -> 17.0, 3 -> 17.0, 4 -> 0.0))
(2,Map(5 -> 7.0, 1 -> 7.0, 6 -> 15.0, 2 -> 0.0, 7 -> 16.0, 3 -> 8.0, 4 -> 9.0))
(7,Map(5 -> 9.0, 1 -> 22.0, 6 -> 11.0, 2 -> 16.0, 7 -> 0.0, 3 -> 14.0, 4 -> 17.0))
(5,Map(5 -> 0.0, 1 -> 14.0, 6 -> 8.0, 2 -> 7.0, 7 -> 9.0, 3 -> 5.0, 4 -> 14.0))
(3,Map(5 -> 5.0, 1 -> 15.0, 6 -> 13.0, 2 -> 8.0, 7 -> 14.0, 3 -> 0.0, 4 -> 17.0))
(1,Map(5 -> 14.0, 1 -> 0.0, 6 -> 11.0, 2 -> 7.0, 7 -> 22.0, 3 -> 15.0, 4 -> 5.0))
(6,Map(5 -> 8.0, 1 -> 11.0, 6 -> 0.0, 2 -> 15.0, 7 -> 11.0, 3 -> 13.0, 4 -> 6.0))