mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

@@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering

 import scala.collection.mutable.ArrayBuffer

-import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm}
+import breeze.linalg.{DenseVector => BDV, Vector => BV}

 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
@@ -125,7 +125,7 @@ class KMeans private (
     }

     // Compute squared norms and cache them.
-    val norms = data.map(v => breezeNorm(v.toBreeze, 2.0))
+    val norms = data.map(Vectors.norm(_, 2.0))
     norms.persist()
     val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
       new BreezeVectorWithNorm(v, norm)
@@ -425,7 +425,7 @@ object KMeans {
 private[clustering]
 class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double) extends Serializable {

-  def this(vector: BV[Double]) = this(vector, breezeNorm(vector, 2.0))
+  def this(vector: BV[Double]) = this(vector, Vectors.norm(Vectors.fromBreeze(vector), 2.0))

   def this(array: Array[Double]) = this(new BDV[Double](array))
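Background on why KMeans caches these norms (an aside, not part of the PR discussion): with per-point norms precomputed and persisted, a squared Euclidean distance reduces to a single dot product via the identity

  \|a - b\|^2 = \|a\|^2 - 2\, a^\top b + \|b\|^2

so each distance evaluation in the clustering iterations reuses the cached \|a\| and \|b\| instead of recomputing them.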
mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala

@@ -17,8 +17,6 @@

 package org.apache.spark.mllib.feature

-import breeze.linalg.{norm => brzNorm}
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}

@@ -47,7 +45,7 @@ class Normalizer(p: Double) extends VectorTransformer {
    * @return normalized vector. If the norm of the input is zero, it will return the input vector.
    */
   override def transform(vector: Vector): Vector = {
-    val norm = brzNorm(vector.toBreeze, p)
+    val norm = Vectors.norm(vector, p)

     if (norm != 0.0) {
       // For dense vector, we've to allocate new memory for new output vector.
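For context, a minimal sketch of the transformation this hunk feeds the norm into, for finite p (an illustration paraphrasing Normalizer's documented behavior on a raw array, not code from the PR):

  // Scale every component by the p-norm; if the norm is zero, return the
  // input unchanged, matching the @return contract quoted in the diff above.
  def normalize(values: Array[Double], p: Double): Array[Double] = {
    require(p >= 1.0)
    val norm = math.pow(values.map(v => math.pow(math.abs(v), p)).sum, 1.0 / p)
    if (norm != 0.0) values.map(_ / norm) else values
  }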
mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala (51 additions, 0 deletions)

@@ -261,6 +261,57 @@ object Vectors {
         sys.error("Unsupported Breeze vector type: " + v.getClass.getName)
     }
   }
+
+  /**
+   * Returns the p-norm of this vector.
+   * @param vector input vector.
+   * @param p norm.
+   * @return norm in L^p^ space.
+   */
+  private[spark] def norm(vector: Vector, p: Double): Double = {
+    require(p >= 1.0)
+    val values = vector match {
+      case dv: DenseVector => dv.values
+      case sv: SparseVector => sv.values
+      case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+    }
+    val size = values.size
+
+    if (p == 1) {

[Review thread on the added line "if (p == 1) {"]

Member: How about p match { ... } here, with @switch to ensure it's just a lookup? It should be faster even than ifs.

Member Author: Interesting, will try tomorrow. But I don't expect too much difference.

Member Author: In bytecode, there is no direct switch operation. As a result, the switch or pattern matching will be compiled into if statements in the bytecode. See the following example:

  def fun1(p: Double) = {
    p match {
      case 1.0 => 1.0
      case 2.0 => 2.0
      case _ => p
    }
  }

  def fun2(p: Double) = {
    if (p == 1.0) 1.0
    else if (p == 2.0) 2.0
    else p
  }

These will be compiled to:

  // access flags 0x1
  public fun1(D)D
   L0
    LINENUMBER 145 L0
    DLOAD 1
    DSTORE 3
   L1
    LINENUMBER 146 L1
    DCONST_1
    DLOAD 3
    DCMPL
    IFNE L2
    DCONST_1
    DSTORE 5
    GOTO L3
   L2
    LINENUMBER 147 L2
   FRAME APPEND [D]
    LDC 2.0
    DLOAD 3
    DCMPL
    IFNE L4
    LDC 2.0
    DSTORE 5
    GOTO L3
   L4
    LINENUMBER 148 L4
   FRAME SAME
    DLOAD 1
    DSTORE 5
   L3
    LINENUMBER 145 L3
   FRAME APPEND [D]
    DLOAD 5
    DRETURN
   L5
    LOCALVARIABLE this Lorg/apache/spark/mllib/stat/Test$; L0 L5 0
    LOCALVARIABLE p D L0 L5 1
    MAXSTACK = 4
    MAXLOCALS = 7

  // access flags 0x1
  public fun2(D)D
   L0
    LINENUMBER 153 L0
    DLOAD 1
    DCONST_1
    DCMPL
    IFNE L1
    DCONST_1
    GOTO L2
   L1
    LINENUMBER 154 L1
   FRAME SAME
    DLOAD 1
    LDC 2.0
    DCMPL
    IFNE L3
    LDC 2.0
    GOTO L2
   L3
    LINENUMBER 155 L3
   FRAME SAME
    DLOAD 1
   L2
    LINENUMBER 153 L2
   FRAME SAME1 D
    DRETURN
   L4
    LOCALVARIABLE this Lorg/apache/spark/mllib/stat/Test$; L0 L4 0
    LOCALVARIABLE p D L0 L4 1
    MAXSTACK = 4
    MAXLOCALS = 3

Member: This is an interesting tangent. What happens if you add @switch? http://www.scala-lang.org/api/current/index.html#scala.annotation.switch Bytecode should have instructions for switch statements that aren't just conditionals, like tableswitch: https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-3.html#jvms-3.10

Member Author: ha~ It only works if I change the type from Double to Int. See the Oracle doc you referenced: "The Java Virtual Machine's tableswitch and lookupswitch instructions operate only on int data. Because operations on byte, char, or short values are internally promoted to int, a switch whose expression evaluates to one of those types is compiled as though it evaluated to type int."

With

  def fun1(p: Int) = {
    (p: @switch) match {
      case 1 => 1
      case 2 => 2
      case _ => p
    }
  }

I got

  public fun1(I)I
   L0
    LINENUMBER 147 L0
    ILOAD 1
    ISTORE 2
    ILOAD 2
    TABLESWITCH
      1: L1
      2: L2
      default: L3
   L3
    LINENUMBER 150 L3
   FRAME APPEND [I]
    ILOAD 1
    GOTO L4
   L2
    LINENUMBER 149 L2
   FRAME SAME
    ICONST_2
    GOTO L4
   L1
    LINENUMBER 148 L1
   FRAME SAME
    ICONST_1
   L4
    LINENUMBER 147 L4
   FRAME SAME1 I
    IRETURN
   L5
    LOCALVARIABLE this Lorg/apache/spark/mllib/stat/Test$; L0 L5 0
    LOCALVARIABLE p I L0 L5 1
    MAXSTACK = 1
    MAXLOCALS = 3

Contributor: It is an interesting discussion ~ :) But maybe more people are familiar with the if ... else if ... else statement. And this is not on the critical path.

Member Author: Yeah, but even with @switch here, the code will not be optimized unless p has type Int.
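To make the thread's conclusion reproducible, a standalone sketch (an illustration, not code from the PR; the function names are made up):

  import scala.annotation.switch

  // Int scrutinee with literal cases: scalac can emit a TABLESWITCH.
  def branchInt(p: Int): Int = (p: @switch) match {
    case 1 => 1
    case 2 => 2
    case _ => 0
  }

  // Double scrutinee: the JVM has no switch instruction for doubles, so this
  // stays a chain of DCMPL/IFNE comparisons, and scalac reports (a warning or
  // an error, depending on version) that it could not emit a switch for the
  // @switch-annotated match.
  def branchDouble(p: Double): Double = (p: @switch) match {
    case 1.0 => 1.0
    case 2.0 => 2.0
    case _ => p
  }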

+      var sum = 0.0
+      var i = 0
+      while (i < size) {
+        sum += math.abs(values(i))
+        i += 1
+      }
+      sum
+    } else if (p == 2) {
+      var sum = 0.0
+      var i = 0
+      while (i < size) {
+        sum += values(i) * values(i)

[Review thread on the added line "sum += values(i) * values(i)"]

Member: Since this is a common case and a tight loop, avoid the duplicated value lookup?

Member Author: I tried that, and the JVM will optimize it, so there is no performance difference.

Member Author: Will try to see if Scala generates the same bytecode tomorrow. Maybe the Scala compiler optimizes it.

Member Author: They generate different bytecode, but I don't see a performance difference. Maybe DSTORE to store the value to a local variable is more expensive than looking it up twice. Or maybe the JVM just optimizes it internally. I don't have a preference.

  sum += values(i) * values(i)

will translate to:

   L24
    LINENUMBER 292 L24
    DLOAD 13
    ALOAD 4
    ILOAD 15
    DALOAD
    ALOAD 4
    ILOAD 15
    DALOAD
    DMUL
    DADD
    DSTORE 13

while

  val value = values(i)
  sum += value * value

will translate to:

   L24
    LINENUMBER 292 L24
    ALOAD 4
    ILOAD 15
    DALOAD
    DSTORE 16
   L25
    LINENUMBER 293 L25
    DLOAD 13
    DLOAD 16
    DLOAD 16
    DMUL
    DADD
    DSTORE 13

Member: Oh OK, didn't mean to make you spend so much time investigating. It is not optimized, I imagine, but may be inconsequential even in a tight loop.

Contributor: Just curious, what if you define var value outside the while loop?

Member Author: @mengxr Having var value outside will generate the same bytecode. That makes sense, since in the bytecode it just stores the value back to the stack, so there is no difference between the two versions.

@srowen I'm very curious about this myself. :P I'm using the ASM Bytecode Outline plugin in IntelliJ (https://plugins.jetbrains.com/plugin/5918), so I can generate the bytecode with a simple right click. We used it to find a couple of boxing/unboxing performance issues at Alpine, and it's very useful.

+        i += 1
+      }
+      math.sqrt(sum)
+    } else if (p == Double.PositiveInfinity) {
+      var max = 0.0
+      var i = 0
+      while (i < size) {
+        val value = math.abs(values(i))
+        if (value > max) max = value
+        i += 1
+      }
+      max
+    } else {
+      var sum = 0.0
+      var i = 0
+      while (i < size) {
+        sum += math.pow(math.abs(values(i)), p)
+        i += 1
+      }
+      math.pow(sum, 1.0 / p)
+    }
+  }
 }

 /**
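For reference, the quantity norm computes is the standard vector p-norm, and the three fast-path branches match its special cases:

  \|x\|_p = \Bigl( \sum_{i=1}^{n} |x_i|^p \Bigr)^{1/p}, \quad p \ge 1,

  \|x\|_1 = \sum_{i=1}^{n} |x_i|, \qquad
  \|x\|_2 = \sqrt{ \sum_{i=1}^{n} x_i^2 }, \qquad
  \|x\|_\infty = \max_{1 \le i \le n} |x_i|.

Iterating only over the stored values of a SparseVector is sound here because implicit zeros contribute nothing to any of the sums, and max starts at 0.0.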
mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

@@ -21,6 +21,7 @@ import breeze.linalg.{DenseMatrix => BDM}
 import org.scalatest.FunSuite

 import org.apache.spark.SparkException
+import org.apache.spark.mllib.util.TestingUtils._

 class VectorsSuite extends FunSuite {

@@ -197,4 +198,27 @@ class VectorsSuite extends FunSuite {
     assert(svMap.get(2) === Some(3.1))
     assert(svMap.get(3) === Some(0.0))
   }
+
+  test("vector p-norm") {
+    val dv = Vectors.dense(0.0, -1.2, 3.1, 0.0, -4.5, 1.9)
+    val sv = Vectors.sparse(6, Seq((1, -1.2), (2, 3.1), (3, 0.0), (4, -4.5), (5, 1.9)))
+
+    assert(Vectors.norm(dv, 1.0) ~== dv.toArray.foldLeft(0.0)((a, v) =>
+      a + math.abs(v)) relTol 1E-8)
+    assert(Vectors.norm(sv, 1.0) ~== sv.toArray.foldLeft(0.0)((a, v) =>
+      a + math.abs(v)) relTol 1E-8)
+
+    assert(Vectors.norm(dv, 2.0) ~== math.sqrt(dv.toArray.foldLeft(0.0)((a, v) =>
+      a + v * v)) relTol 1E-8)
+    assert(Vectors.norm(sv, 2.0) ~== math.sqrt(sv.toArray.foldLeft(0.0)((a, v) =>
+      a + v * v)) relTol 1E-8)
+
+    assert(Vectors.norm(dv, Double.PositiveInfinity) ~== dv.toArray.map(math.abs).max relTol 1E-8)
+    assert(Vectors.norm(sv, Double.PositiveInfinity) ~== sv.toArray.map(math.abs).max relTol 1E-8)
+
+    assert(Vectors.norm(dv, 3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
+      a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
+    assert(Vectors.norm(sv, 3.7) ~== math.pow(sv.toArray.foldLeft(0.0)((a, v) =>
+      a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
+  }
 }
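A quick hand-checkable example of the new method (an illustration, not part of the PR; note that norm is private[spark], so it is only callable from Spark's own packages):

  // 3-4-5 triangle: the Euclidean norm of (3.0, 4.0) is exactly 5.0.
  val v = Vectors.dense(3.0, 4.0)
  Vectors.norm(v, 1.0)                      // 7.0
  Vectors.norm(v, 2.0)                      // 5.0
  Vectors.norm(v, Double.PositiveInfinity)  // 4.0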