diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala index 92f3f46ec8920..0bc8a8715e7fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala @@ -128,22 +128,11 @@ class CodeGenerator extends Logging { ${getColumn(inputTuple, b.dataType, ordinal)} """.children - case expressions.Literal(value: String, dataType) => + case expressions.Literal(value, dataType) => q""" val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value + val $primitiveTerm: ${termForType(dataType)} = ${value.toString} """.children - case expressions.Literal(value: Int, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - case expressions.Literal(value: Long, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - case Cast(e, StringType) => val eval = expressionEvaluator(e) eval.code ++ @@ -188,19 +177,55 @@ class CodeGenerator extends Logging { $primitiveTerm = true } """.children + case Or(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && ${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = true + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = false + } + """.children case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => 
q"$eval1 - $eval2" } case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" } + case Remainder(e1, e2) =>(e1, e2) evaluate { case (eval1, eval2) => q"$eval1 % $eval2" } + + case UnaryMinus(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm: ${termForType(e.dataType)} = -${eval.primitiveTerm} + """.children case IsNotNull(e) => val eval = expressionEvaluator(e) q""" ..${eval.code} var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm}.unary_! + """.children + + case IsNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} """.children case c @ Coalesce(children) => @@ -221,6 +246,17 @@ class CodeGenerator extends Logging { """ } + case Not(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = ${eval.nullTerm} + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.primitiveTerm}.unary_! 
+      """.children
+
+    // TODO transform the In to If
+//    case In(v, list) =>
+
     case i @ expressions.If(condition, trueValue, falseValue) =>
       val condEval = expressionEvaluator(condition)
       val trueEval = expressionEvaluator(trueValue)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 6e585236b1b20..a5ab99f208bb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,11 +19,133 @@ package org.apache.spark.sql
 package catalyst
 package expressions
 
+import scala.util.matching.Regex
+
+import catalyst.types.StringType
 import catalyst.types.BooleanType
+import analysis.UnresolvedException
+import catalyst.errors.`package`.TreeNodeException
+
+
+/**
+ * Base class for binary predicates over two string-typed operands. The expression only
+ * counts as resolved when both children are resolved and of StringType, and its result
+ * type is BooleanType.
+ */
+abstract class BinaryString extends BinaryExpression {
+  self: Product =>
+
+  type EvaluatedType = Any
+
+  def nullable = left.nullable || right.nullable
+
+  override lazy val resolved =
+    left.resolved && right.resolved && left.dataType == StringType && right.dataType == StringType
+
+  def dataType = {
+    if (!resolved) {
+      throw new UnresolvedException(this,
+        s"datatype. Can not resolve due to non string types ${left.dataType}, ${right.dataType}")
+    }
+
+    BooleanType
+  }
+
+  /**
+   * Evaluates `e1` and `e2` against row `i` and applies `f` to the two string results.
+   * Returns null as soon as either operand evaluates to null (`e2` is not evaluated when
+   * `e1` is null).
+   *
+   * @throws TreeNodeException if either operand is not of StringType
+   */
+  @inline
+  protected final def s2(
+      i: Row,
+      e1: Expression,
+      e2: Expression,
+      f: ((String, String) => Boolean)): Any = {
+
+    if (e1.dataType != StringType) {
+      throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != StringType")
+    }
 
-case class Like(left: Expression, right: Expression) extends BinaryExpression {
-  def dataType = BooleanType
-  def nullable = left.nullable // Right cannot be null.
+    if (e2.dataType != StringType) {
+      throw new TreeNodeException(this, s"Types do not match ${e2.dataType} != StringType")
+    }
+
+    val evalE1 = e1.apply(i)
+    if(evalE1 == null) {
+      null
+    } else {
+      val evalE2 = e2.apply(i)
+      if (evalE2 == null) {
+        null
+      } else {
+        f.apply(evalE1.asInstanceOf[String], evalE2.asInstanceOf[String])
+      }
+    }
+  }
+
+  /**
+   * Evaluates `e1` against row `i` and applies `f` to the string result, or returns null
+   * when `e1` evaluates to null.
+   *
+   * @throws TreeNodeException if `e1` is not of StringType
+   */
+  @inline
+  protected final def s1(
+      i: Row,
+      e1: Expression,
+      f: ((String) => Boolean)): Any = {
+
+    if (e1.dataType != StringType) {
+      throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != StringType")
+    }
+
+    val evalE1 = e1.apply(i)
+    if(evalE1 == null) {
+      null
+    } else {
+      f.apply(evalE1.asInstanceOf[String])
+    }
+  }
+}
+
+/**
+ * SQL LIKE against a literal pattern. The pattern is translated to a regular expression
+ * (`_` to `.{1}`, `%` to `.*`) and the predicate holds when that regex is found anywhere
+ * in the left operand. Evaluates to null when the pattern or the left value is null.
+ *
+ * NOTE(review): other regex metacharacters in the pattern are not escaped, and the match
+ * is not anchored to the whole string -- confirm whether that is intended.
+ */
+case class Like(left: Expression, right: Literal) extends BinaryString {
   def symbol = "LIKE"
+  // replace the _ with .{1} exactly match 1 time of any character
+  // replace the % with .*, match 0 or more times with any character
+  def regex(v: String) = v.replaceAll("_", ".{1}").replaceAll("%", ".*")
+  lazy val r = regex(right.value.asInstanceOf[String]).r
+
+  override def apply(input: Row): Any = if(right.value == null) {
+    null
+  } else {
+    s1(input, left, r.findFirstIn(_) != None)
+  }
 }
 
+/**
+ * Regular-expression match against a literal pattern: true when the regex is found
+ * anywhere in the left operand. Evaluates to null when the pattern or the left value
+ * is null.
+ */
+case class RLike(left: Expression, right: Literal) extends BinaryString {
+  def symbol = "RLIKE"
+
+  lazy val r = right.value.asInstanceOf[String].r
+
+  override def apply(input: Row): Any = if(right.value == null) {
+    null
+  } else {
+    s1(input, left, r.findFirstIn(_) != None)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala
deleted file mode 100644
index d11ff0dd0238d..0000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-package org.apache.spark.sql
-package catalyst
-package expressions
-
-import org.scalatest.FunSuite
-
-import types._
-import expressions._
-
-import 
dsl.expressions._ - -class CodeGenerationSuite extends FunSuite { - - val data = Array.fill(5)(new GenericRow(Array(1, null, 1.0))) - - // TODO add to DSL - val c1 = BoundReference(0, AttributeReference("a", IntegerType)()) - val c2 = BoundReference(1, AttributeReference("b", IntegerType)()) - val c3 = BoundReference(2, AttributeReference("c", DoubleType)()) - - test("simple") { - val generator = - GenerateProjection(Array(c1, c2, c3, Add(c1, c1), Add(c1, c2), Add(c2, c2), Subtract(c1, c2))) - val generatedRow = generator(data.head) - - // TODO: Factor out or use :javap? - val generatedClass = generatedRow.getClass - val classLoader = - generatedClass - .getClassLoader - .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader] - val generatedBytes = classLoader.classBytes(generatedClass.getName) - - val outfile = new java.io.FileOutputStream("generated.class") - outfile.write(generatedBytes) - outfile.close() - - println(generatedRow.length) - println(generatedRow) - println(generatedRow.getInt(0)) - println(generatedRow.isNullAt(1)) - - } - - test("ordering") { - val ordering = GenerateOrdering(Seq(c1, c2, Subtract(c1, c1)).map(_.asc)) - - val generatedClass = ordering.getClass - val classLoader = - generatedClass - .getClassLoader - .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader] - val generatedBytes = classLoader.classBytes(generatedClass.getName) - val outfile = new java.io.FileOutputStream("ordering.class") - outfile.write(generatedBytes) - outfile.close() - data.toArray.sorted(ordering) - } - - /* TODO: Just test serialization. 
- test("in rdd") { - val c1 = BoundReference(0, AttributeReference("a", IntegerType)()) - val projection = GenerateRow(c1 :: Nil) - val rdd = TestSqlContext.sc.makeRDD((1 to 1000).map(i => new GenericRow(i :: Nil))) - rdd.mapPartitions(projection).collect() - } - */ -} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala index 0483a059bfe6e..192b1d59fe97c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala @@ -20,15 +20,170 @@ package catalyst package expressions import org.scalatest.FunSuite - import types._ import expressions._ - import dsl._ import dsl.expressions._ + +abstract class ExprEval(exprs: Array[Expression]) { + type Execution = (Row => Row) + + def engine: Execution + + private val compare: PartialFunction[(Int, (Boolean, Any), Row), Boolean] = { + case (idx, (false, field), row) => { + (row.isNullAt(idx) == false) && (row.apply(idx) == field) + } + case (idx, (true, _), row) => { + row.isNullAt(idx) + } + } + + def verify(expected: Array[(Boolean, Any)], result: Row) { + Seq.tabulate(expected.size) { i => + assert(compare.lift(i, expected(i), result).get) + } + } + + def verify(expected: Array[Array[(Boolean, Any)]], inputs: Array[Row]) { + val result = inputs.map(engine.apply(_)) + + expected.zip(result).foreach { case (expectedRow, row) => verify(expectedRow, row) } + } +} + +case class CGExprEval(exprs: Array[Expression]) extends ExprEval(exprs) { + override def engine: Execution = GenerateProjection(exprs) +} + +case class InterpretEngine(exprs: Array[Expression]) extends ExprEval(exprs) { + override def engine: Execution = new InterpretedProjection(exprs) +} + class ExpressionEvaluationSuite extends FunSuite { + val data = 
Array.fill[Row](5)(new GenericRow(Array(1, null, 1.0, true, 4, 5, null, "abcccd"))) + + // TODO add to DSL + val c1 = BoundReference(0, AttributeReference("a", IntegerType)()) + val c2 = BoundReference(1, AttributeReference("b", IntegerType)()) + val c3 = BoundReference(2, AttributeReference("c", DoubleType)()) + val c4 = BoundReference(3, AttributeReference("d", BooleanType)()) + val c5 = BoundReference(4, AttributeReference("e", IntegerType)()) + val c6 = BoundReference(5, AttributeReference("f", IntegerType)()) + val c7 = BoundReference(6, AttributeReference("g", StringType)()) + val c8 = BoundReference(7, AttributeReference("h", StringType)()) + + test("simple") { + val generator = + GenerateProjection(Array(c1, c2, c3, Add(c1, c1), Add(c1, c2), Add(c2, c2), Subtract(c1, c2))) + val generatedRow = generator(data.head) + // TODO: Factor out or use :javap? + val generatedClass = generatedRow.getClass + val classLoader = + generatedClass + .getClassLoader + .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader] + val generatedBytes = classLoader.classBytes(generatedClass.getName) + + val outfile = new java.io.FileOutputStream("generated.class") + outfile.write(generatedBytes) + outfile.close() + + println(generatedRow.length) + println(generatedRow) + println(generatedRow.getInt(0)) + println(generatedRow.isNullAt(1)) + } + + def verify(exprs: Array[Expression], expecteds: Array[Array[(Boolean, Any)]], input: Array[Row]) { + val eval1 = CGExprEval(exprs) + val eval2 = InterpretEngine(exprs) + + eval1.verify(expecteds, input) + eval2.verify(expecteds, input) + } + + test("logical") { + val expecteds = Array.fill(5)(Array[(Boolean, Any)]( + (false, false), + (true, -1), + (false, true), + (false, true), + (false, false))) + val exprs = Array[Expression](And(LessThan(Cast(c1, DoubleType), c3), LessThan(c1, c2)), + Or(LessThan(Cast(c1, DoubleType), c3), LessThan(c1, c2)), + IsNull(c2), + IsNotNull(c3), + Not(c4)) + + verify(exprs, expecteds, data) + } + + 
test("arithmetic") {
+    val exprs = Array[Expression](
+      Add(c1, c2),
+      Add(c1, c5),
+      Divide(c1, c5),
+      Subtract(c1, c5),
+      Multiply(c1, c5),
+      Remainder(c1, c5),
+      UnaryMinus(c1)
+    )
+    val data = Array.fill[Row](5)(new GenericRow(Array(1, null, 1.0, true, 4, 5)))
+    val expecteds = Array.fill(5)(Array[(Boolean, Any)](  // one (isNull, value) pair per expr
+      (true, 0),     // c1 + c2: c2 is null
+      (false, 5),    // 1 + 4
+      (false, 0),    // 1 / 4
+      (false, -3),   // 1 - 4
+      (false, 4),    // 1 * 4
+      (false, 1),    // 1 % 4
+      (false, -1)))  // -(1)
+
+    verify(exprs, expecteds, data)
+  }
+
+  test("string like / rlike") {
+    val exprs = Array[Expression](
+      Like(c7, Literal("a", StringType)),
+      Like(c7, Literal(null, StringType)),
+      Like(c8, Literal(null, StringType)),
+      Like(c8, Literal("a_c", StringType)),
+      Like(c8, Literal("a%c", StringType)),
+      RLike(c7, Literal("a+", StringType)),
+      RLike(c7, Literal(null, StringType)),
+      RLike(c8, Literal(null, StringType)),
+      RLike(c8, Literal("a%c", StringType))
+    )
+
+    val expecteds = Array.fill(data.length)(Array[(Boolean, Any)](  // one pair per expr above
+      (true, false),   // Like(c7, "a"): c7 is null
+      (true, false),   // Like(c7, null)
+      (true, false),   // Like(c8, null)
+      (false, true),   // Like(c8, "a_c")
+      (false, true),   // Like(c8, "a%c")
+      (true, false),   // RLike(c7, "a+"): c7 is null
+      (true, false),   // RLike(c7, null)
+      (true, true), (false, false)))  // RLike(c8, null); RLike(c8, "a%c") has no match
+    verify(exprs, expecteds, data)
+  }
+
+  test("ordering") {
+    val ordering = GenerateOrdering(Seq(c1, c2, Subtract(c1, c1)).map(_.asc))
+
+    val generatedClass = ordering.getClass
+    val classLoader =
+      generatedClass
+        .getClassLoader
+        .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader]
+    val generatedBytes = classLoader.classBytes(generatedClass.getName)
+    val outfile = new java.io.FileOutputStream("ordering.class")
+    outfile.write(generatedBytes)
+    outfile.close()
+    data.toArray.sorted(ordering)
+  }
+
   test("literals") {
     assert((Literal(1) + Literal(1)).apply(null) === 2)
   }
@@ -54,22 +209,21 @@ class ExpressionEvaluationSuite extends FunSuite {
    * Unknown Unknown
    */
 
-  val notTrueTable =
-    (true, false) ::
-    (false, true) ::
-    (null, null) :: Nil
-
+  val b1 = BoundReference(0, AttributeReference("a", BooleanType)())
+  val b2 = BoundReference(1, AttributeReference("b", BooleanType)())
+ test("3VL Not") { - notTrueTable.foreach { - case (v, answer) => - val expr = Not(Literal(v, BooleanType)) - val result = expr.apply(null) - if (result != answer) - fail(s"$expr should not evaluate to $result, expected: $answer") } + val table = (true, false) :: (false, true) :: (null, null) :: Nil + + val exprs = Array[Expression](Not(b1)) + val inputs = table.map { case(v, answer) => new GenericRow(Array(v)) } + val expected = table.map { case(v, answer) => Array((answer == null, answer)) } + + verify(exprs, expected.toArray, inputs.toArray) } - booleanLogicTest("AND", _ && _, - (true, true, true) :: + test("3VL AND") { + val table = (true, true, true) :: (true, false, false) :: (true, null, null) :: (false, true, false) :: @@ -77,10 +231,17 @@ class ExpressionEvaluationSuite extends FunSuite { (false, null, false) :: (null, true, null) :: (null, false, false) :: - (null, null, null) :: Nil) + (null, null, null) :: Nil + + val exprs = Array[Expression](And(b1, b2)) + val inputs = table.map { case(v1, v2, answer) => new GenericRow(Array(v1, v2)) } + val expected = table.map { case(v1, v2, answer) => Array((answer == null, answer)) } + + verify(exprs, expected.toArray, inputs.toArray) + } - booleanLogicTest("OR", _ || _, - (true, true, true) :: + test("3VL OR") { + val table = (true, true, true) :: (true, false, true) :: (true, null, true) :: (false, true, true) :: @@ -88,10 +249,17 @@ class ExpressionEvaluationSuite extends FunSuite { (false, null, null) :: (null, true, true) :: (null, false, null) :: - (null, null, null) :: Nil) - - booleanLogicTest("=", _ === _, - (true, true, true) :: + (null, null, null) :: Nil + + val exprs = Array[Expression](Or(b1, b2)) + val inputs = table.map { case(v1, v2, answer) => new GenericRow(Array(v1, v2)) } + val expected = table.map { case(v1, v2, answer) => Array((answer == null, answer)) } + + verify(exprs, expected.toArray, inputs.toArray) + } + + test("3VL Equals") { + val table = (true, true, true) :: (true, false, false) 
:: (true, null, null) :: (false, true, false) :: @@ -99,17 +267,12 @@ class ExpressionEvaluationSuite extends FunSuite { (false, null, null) :: (null, true, null) :: (null, false, null) :: - (null, null, null) :: Nil) - - def booleanLogicTest(name: String, op: (Expression, Expression) => Expression, truthTable: Seq[(Any, Any, Any)]) { - test(s"3VL $name") { - truthTable.foreach { - case (l,r,answer) => - val expr = op(Literal(l, BooleanType), Literal(r, BooleanType)) - val result = expr.apply(null) - if (result != answer) - fail(s"$expr should not evaluate to $result, expected: $answer") - } - } + (null, null, null) :: Nil + + val exprs = Array[Expression](Equals(b1, b2)) + val inputs = table.map { case(v1, v2, answer) => new GenericRow(Array(v1, v2)) } + val expected = table.map { case(v1, v2, answer) => Array((answer == null, answer)) } + + verify(exprs, expected.toArray, inputs.toArray) } } \ No newline at end of file