频道栏目
首页 > 资讯 > 其他综合 > 正文

Spark GraphX之Dijkstra(单源最短路径)、Prime(最小生成树)、FloydWarshall(多源最短路径)

18-07-20        来源:[db:作者]  
收藏   我要投稿

输入图片说明

ShortestPaths的源码如下:

package org.apache.spark.graphx.lib

import scala.reflect.ClassTag

import org.apache.spark.graphx._

/**
 * Computes shortest paths to the given set of landmark vertices, returning a graph where each
 * vertex attribute is a map containing the shortest-path distance to each reachable landmark.
 */
object ShortestPaths {
  /** Stores a map from the vertex id of a landmark to the distance to that landmark. */
  type SPMap = Map[VertexId, Int]

  private def makeMap(x: (VertexId, Int)*) = Map(x: _*)

  private def incrementMap(spmap: SPMap): SPMap = spmap.map { case (v, d) => v -> (d + 1) }

  private def addMaps(spmap1: SPMap, spmap2: SPMap): SPMap =
    (spmap1.keySet ++ spmap2.keySet).map {
      k => k -> math.min(spmap1.getOrElse(k, Int.MaxValue), spmap2.getOrElse(k, Int.MaxValue))
    }.toMap

  /**
   * Computes shortest paths to the given set of landmark vertices.
   *
   * @tparam ED the edge attribute type (not used in the computation)
   *
   * @param graph the graph for which to compute the shortest paths
   * @param landmarks the list of landmark vertex ids. Shortest paths will be computed to each
   * landmark.
   *
   * @return a graph where each vertex attribute is a map containing the shortest-path distance to
   * each reachable landmark vertex.
   */
  def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
    val spGraph = graph.mapVertices { (vid, attr) =>
      if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
    }

    val initialMessage = makeMap()

    def vertexProgram(id: VertexId, attr: SPMap, msg: SPMap): SPMap = {
      addMaps(attr, msg)
    }

    def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = {
      val newAttr = incrementMap(edge.dstAttr)
      if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr))
      else Iterator.empty
    }

    Pregel(spGraph, initialMessage)(vertexProgram, sendMessage, addMaps)
  }
}

关于单源最短路径,我们可以调用 ShortestPaths .run(graph, landmarks) 得到graph中的顶点到landmarks的“距离”,但是这个“距离”只是“跳数”。换句话说,只在graph中每条边的权重都为1的情况下,才能保证结果的正确性。而现实情况中,往往都不满足这个条件。那么问题来了,我们该如何做呢?学过图论的朋友都知道,Dijkstra算法可以解决这个问题。遗憾的是,GraphX目前(Spark2.0.2)并未提供这样的API,所以基于GraphX实现Dijkstra算法变得很有必要。

Dijkstra(单源最短路径)

  //单源最短路径
  def dijkstra[VD: ClassTag](g : Graph[VD, Double], origin: VertexId) = {
    //初始化,其中属性为(boolean, double,Long)类型,boolean用于标记是否访问过,double为顶点距离原点的距离,Long是上一个顶点的id
    var g2 = g.mapVertices((vid, _) => (false, if(vid == origin) 0 else Double.MaxValue, -1L))

    for(i <- 1L to g.vertices.count()) {
      //从没有访问过的顶点中找出距离原点最近的点
      val currentVertexId = g2.vertices.filter(! _._2._1).reduce((a,b) => if (a._2._2 < b._2._2) a else b)._1
      //更新currentVertexId邻接顶点的‘double’值
      val newDistances = g2.aggregateMessages[(Double, Long)](
        triplet => if(triplet.srcId == currentVertexId && !triplet.dstAttr._1) {    //只给未确定的顶点发送消息
          triplet.sendToDst((triplet.srcAttr._2 + triplet.attr, triplet.srcId))
        },
        (x, y) => if(x._1 < y._1) x else y ,
        TripletFields.All
      )
      //newDistances.foreach(x => println("currentVertexId\t"+currentVertexId+"\t->\t"+x))
      //更新图形
      g2 = g2.outerJoinVertices(newDistances) {
        case (vid, vd, Some(newSum)) => (vd._1 || vid == currentVertexId, math.min(vd._2, newSum._1), if(vd._2 <= newSum._1) vd._3 else newSum._2 )
        case (vid, vd, None) => (vd._1|| vid == currentVertexId, vd._2, vd._3)
      }
      //g2.vertices.foreach(x => println("currentVertexId\t"+currentVertexId+"\t->\t"+x))
    }

    //g2
    g.outerJoinVertices(g2.vertices)( (vid, srcAttr, dist) => (srcAttr, dist.getOrElse(false, Double.MaxValue, -1)._2, dist.getOrElse(false, Double.MaxValue, -1)._3) )
  }

Prime(最小生成树)

知道Dijkstra算法的人也一定知道Prime算法。

  //最小生成树
  def prime[VD: ClassTag](g : Graph[VD, Double], origin: VertexId) = {
    //初始化,其中属性为(boolean, double,Long)类型,boolean用于标记是否访问过,double为加入当前顶点的代价,Long是上一个顶点的id
    var g2 = g.mapVertices((vid, _) => (false, if(vid == origin) 0 else Double.MaxValue, -1L))

    for(i <- 1L to g.vertices.count()) {
      //从没有访问过的顶点中找出 代价最小 的点
      val currentVertexId = g2.vertices.filter(! _._2._1).reduce((a,b) => if (a._2._2 < b._2._2) a else b)._1
      //更新currentVertexId邻接顶点的‘double’值
      val newDistances = g2.aggregateMessages[(Double, Long)](
        triplet => if(triplet.srcId == currentVertexId && !triplet.dstAttr._1) {    //只给未确定的顶点发送消息
          triplet.sendToDst((triplet.attr, triplet.srcId))
        },
        (x, y) => if(x._1 < y._1) x else y ,
        TripletFields.All
      )
      //newDistances.foreach(x => println("currentVertexId\t"+currentVertexId+"\t->\t"+x))
      //更新图形
      g2 = g2.outerJoinVertices(newDistances) {
        case (vid, vd, Some(newSum)) => (vd._1 || vid == currentVertexId, math.min(vd._2, newSum._1), if(vd._2 <= newSum._1) vd._3 else newSum._2 )
        case (vid, vd, None) => (vd._1|| vid == currentVertexId, vd._2, vd._3)
      }
      //g2.vertices.foreach(x => println("currentVertexId\t"+currentVertexId+"\t->\t"+x))
    }

    //g2
    g.outerJoinVertices(g2.vertices)( (vid, srcAttr, dist) => (srcAttr, dist.getOrElse(false, Double.MaxValue, -1)._2, dist.getOrElse(false, Double.MaxValue, -1)._3) )
  }

FloydWarshall(多源最短路径)

  //多源最短路径
  def floydWarshall[VD: ClassTag](g: Graph[VD, Double]) = {
    def mergeMaps(a: Map[VertexId, Double], b: Map[VertexId, Double]) = {
      (a.keySet ++ b.keySet).map{ k => (k, math.min(a.getOrElse(k, Double.MaxValue), b.getOrElse(k, Double.MaxValue))) }.toMap
    }

    val N = g.vertices.count()    //图顶点的个数
    var n = -1
    //初始化图
    var g2 = g.mapVertices( (vid, _) => Map(vid -> 0.0) )

    //当n = N*N时,退出循环。注:不难发现最终结果是一个实对称矩阵
    while(n < N * N) {
      val newVertices = g2.aggregateMessages[Map[VertexId, Double]](
        triplet =>{
          val dstPlus = triplet.dstAttr.map{ case (vid, distance) => (vid, triplet.attr+distance) }
          if(dstPlus != triplet.srcAttr) { triplet.sendToSrc(dstPlus) }
        },
        (a, b) => mergeMaps(a, b) ,
        TripletFields.Dst
      )

      g2 = g2.outerJoinVertices(newVertices)( (_, oldAttr, opt) => mergeMaps(oldAttr, opt.get) )

      n = g2.vertices.map{ case (vid, srcAttr) => srcAttr.size }.reduce(_ + _)
      //println("number\t" + n)
    }

    g2
  }

纸上得来终觉浅,绝知此事要躬行。下面开始实战、实战、实战,重要的事情说三遍!!!

    val myVertices = sc.makeRDD(Array((1L, "A"), (2L, "B"), (3L, "C"), (4L, "D"), (5L, "E"), (6L, "F"), (7L, "G")))
    val initialEdges = sc.makeRDD(Array(Edge(1L, 2L, 7.0), Edge(1L, 4L, 5.0),
                                   Edge(2L, 3L, 8.0), Edge(2L, 4L, 9.0), Edge(2L, 5L, 7.0),
                                   Edge(3L, 5L, 5.0),
                                   Edge(4L, 5L, 15.0), Edge(4L, 6L, 6.0),
                                   Edge(5L, 6L, 8.0), Edge(5L, 7L, 9.0),
                                   Edge(6L, 7L, 11.0)))
    val myEdges = initialEdges.filter(e => e.srcId != e.dstId).flatMap(e => Array(e, Edge(e.dstId, e.srcId, e.attr))).distinct()  //去掉自循环边,有向图变为无向图,去除重复边
    val myGraph = Graph(myVertices, myEdges).cache()

    println(ShortestPaths.run(myGraph, Seq(3)).vertices.collect().mkString(","))
    println(dijkstra(myGraph, 3L).vertices.map(x => (x._1, x._2)).collect().mkString(" | "))
    println(prime(myGraph, 3L).vertices.map(x => (x._1, x._2)).collect().mkString(" | "))
    floydWarshall(myGraph).vertices.foreach(println)

输出依次如下:

ShortestPaths:
(1,Map(3 -> 2)) | (2,Map(3 -> 1)) | (3,Map(3 -> 0)) | (4,Map(3 -> 2)) | (5,Map(3 -> 1)) | (6,Map(3 -> 2)) | (7,Map(3 -> 2))
Dijkstra:
(1,(A,15.0,2)) | (2,(B,8.0,3)) | (3,(C,0.0,-1)) | (4,(D,17.0,2)) | (5,(E,5.0,3)) | (6,(F,13.0,5)) | (7,(G,14.0,5))
Prime:
(1,(A,7.0,2)) | (2,(B,7.0,5)) | (3,(C,0.0,-1)) | (4,(D,5.0,1)) | (5,(E,5.0,3)) | (6,(F,6.0,4)) | (7,(G,9.0,5))
FloydWarshall:
(4,Map(5 -> 14.0, 1 -> 5.0, 6 -> 6.0, 2 -> 9.0, 7 -> 17.0, 3 -> 17.0, 4 -> 0.0))
(2,Map(5 -> 7.0, 1 -> 7.0, 6 -> 15.0, 2 -> 0.0, 7 -> 16.0, 3 -> 8.0, 4 -> 9.0))
(7,Map(5 -> 9.0, 1 -> 22.0, 6 -> 11.0, 2 -> 16.0, 7 -> 0.0, 3 -> 14.0, 4 -> 17.0))
(5,Map(5 -> 0.0, 1 -> 14.0, 6 -> 8.0, 2 -> 7.0, 7 -> 9.0, 3 -> 5.0, 4 -> 14.0))
(3,Map(5 -> 5.0, 1 -> 15.0, 6 -> 13.0, 2 -> 8.0, 7 -> 14.0, 3 -> 0.0, 4 -> 17.0))
(1,Map(5 -> 14.0, 1 -> 0.0, 6 -> 11.0, 2 -> 7.0, 7 -> 22.0, 3 -> 15.0, 4 -> 5.0))
(6,Map(5 -> 8.0, 1 -> 11.0, 6 -> 0.0, 2 -> 15.0, 7 -> 11.0, 3 -> 13.0, 4 -> 6.0))
相关TAG标签
上一篇:Mac下常用的操作命令记录
下一篇:python关于引用,装饰器,列表知识讲解
相关文章
图文推荐

关于我们 | 联系我们 | 广告服务 | 投资合作 | 版权申明 | 在线帮助 | 网站地图 | 作品发布 | Vip技术培训 | 举报中心

版权所有: 红黑联盟--致力于做实用的IT技术学习网站