diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
index e5a4d82b9874c..bc922f854f3aa 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
@@ -18,14 +18,15 @@ package org.apache.spark.graphx.lib
 
 import org.apache.spark.graphx._
+import scala.reflect.ClassTag
 
 object ShortestPaths {
   type SPMap = Map[VertexId, Int] // map of landmarks -> minimum distance to landmark
 
   def SPMap(x: (VertexId, Int)*) = Map(x: _*)
   def increment(spmap: SPMap): SPMap = spmap.map { case (v, d) => v -> (d + 1) }
   def plus(spmap1: SPMap, spmap2: SPMap): SPMap =
-    (spmap1.keySet ++ spmap2.keySet).map{
-      k => k -> scala.math.min(spmap1.getOrElse(k, Int.MaxValue), spmap2.getOrElse(k, Int.MaxValue))
+    (spmap1.keySet ++ spmap2.keySet).map {
+      k => k -> math.min(spmap1.getOrElse(k, Int.MaxValue), spmap2.getOrElse(k, Int.MaxValue))
     }.toMap
 
   /**
@@ -33,9 +34,7 @@ object ShortestPaths {
    * return an RDD with the map of landmarks to their shortest-path
    * lengths.
    *
-   * @tparam VD the shortest paths map for the vertex
-   * @tparam ED the incremented shortest-paths map of the originating
-   * vertex (discarded in the computation)
+   * @tparam ED the edge attribute type (not used in the computation)
   *
   * @param graph the graph for which to compute the shortest paths
   * @param landmarks the list of landmark vertex ids
@@ -43,15 +42,12 @@ object ShortestPaths {
   * @return a graph with vertex attributes containing a map of the
   * shortest paths to each landmark
   */
-  def run[VD, ED](graph: Graph[VD, ED], landmarks: Seq[VertexId])
-    (implicit m1: Manifest[VD], m2: Manifest[ED]): Graph[SPMap, SPMap] = {
-
+  def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
     val spGraph = graph
-      .mapVertices{ (vid, attr) =>
+      .mapVertices { (vid, attr) =>
         if (landmarks.contains(vid)) SPMap(vid -> 0)
         else SPMap()
       }
-      .mapTriplets{ edge => edge.srcAttr }
 
     val initialMessage = SPMap()
 
@@ -59,7 +55,7 @@ object ShortestPaths {
       plus(attr, msg)
     }
 
-    def sendMessage(edge: EdgeTriplet[SPMap, SPMap]): Iterator[(VertexId, SPMap)] = {
+    def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = {
       val newAttr = increment(edge.srcAttr)
       if (edge.dstAttr != plus(newAttr, edge.dstAttr)) Iterator((edge.dstId, newAttr))
       else Iterator.empty
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
index d095d3e791b5b..265827b3341c2 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
@@ -30,13 +30,18 @@ class ShortestPathsSuite extends FunSuite with LocalSparkContext {
 
   test("Shortest Path Computations") {
     withSpark { sc =>
-      val shortestPaths = Set((1,Map(1 -> 0, 4 -> 2)), (2,Map(1 -> 1, 4 -> 2)), (3,Map(1 -> 2, 4 -> 1)),
-        (4,Map(1 -> 2, 4 -> 0)), (5,Map(1 -> 1, 4 -> 1)), (6,Map(1 -> 3, 4 -> 1)))
-      val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)).flatMap{ case e => Seq(e, e.swap) }
+      val shortestPaths = Set(
+        (1, Map(1 -> 0, 4 -> 2)), (2, Map(1 -> 1, 4 -> 2)), (3, Map(1 -> 2, 4 -> 1)),
+        (4, Map(1 -> 2, 4 -> 0)), (5, Map(1 -> 1, 4 -> 1)), (6, Map(1 -> 3, 4 -> 1)))
+      val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)).flatMap {
+        case e => Seq(e, e.swap)
+      }
       val edges = sc.parallelize(edgeSeq).map { case (v1, v2) => (v1.toLong, v2.toLong) }
       val graph = Graph.fromEdgeTuples(edges, 1)
       val landmarks = Seq(1, 4).map(_.toLong)
-      val results = ShortestPaths.run(graph, landmarks).vertices.collect.map { case (v, spMap) => (v, spMap.mapValues(_.get)) }
+      val results = ShortestPaths.run(graph, landmarks).vertices.collect.map {
+        case (v, spMap) => (v, spMap.mapValues(_.get))
+      }
       assert(results.toSet === shortestPaths)
     }
   }