36 commits
685fd07  use UTF8String instead of String for StringType (Mar 31, 2015)
21f67c6  cleanup (Mar 31, 2015)
4699c3a  use Array[Byte] in UTF8String (Mar 31, 2015)
d32abd1  fix utf8 for python api (Mar 31, 2015)
a85fb27  refactor (Mar 31, 2015)
6b499ac  fix style (Apr 1, 2015)
5f9e120  fix sql tests (Apr 1, 2015)
38c303e  fix python sql tests (Apr 1, 2015)
c7dd4d2  fix some catalyst tests (Apr 1, 2015)
bb52e44  fix scala style (Apr 1, 2015)
8b45864  fix codegen with UTF8String (Apr 1, 2015)
23a766c  refactor (Apr 1, 2015)
9dc32d1  fix some hive tests (Apr 2, 2015)
73e4363  Merge branch 'master' of github.com:apache/spark into string (Apr 2, 2015)
956b0a4  fix hive tests (Apr 2, 2015)
9f4c194  convert data type for data source (Apr 2, 2015)
537631c  some comment about Date (Apr 2, 2015)
28d6f32  refactor (Apr 2, 2015)
28f3d81  Merge branch 'master' of github.com:apache/spark into string (Apr 3, 2015)
e5fa5b8  remove clone in UTF8String (Apr 3, 2015)
8d17f21  fix hive compatibility tests (Apr 3, 2015)
fd11364  optimize UTF8String (Apr 3, 2015)
ac18ae6  address comment (Apr 3, 2015)
2089d24  add hashcode check back (Apr 3, 2015)
13d9d42  Merge branch 'master' of github.com:apache/spark into string (Apr 3, 2015)
867bf50  fix String filter push down (Apr 4, 2015)
1314a37  address comments from Yin (Apr 8, 2015)
5116b43  rollback unrelated changes (Apr 8, 2015)
08d897b  Merge branch 'master' of github.com:apache/spark into string (Apr 9, 2015)
b04a19c  add comment for getString/setString (Apr 10, 2015)
744788f  Merge branch 'master' of github.com:apache/spark into string (Apr 13, 2015)
341ec2c  turn off scala style check in UTF8StringSuite (Apr 13, 2015)
59025c8  address comments from @marmbrus (Apr 15, 2015)
6d776a9  Merge branch 'master' of github.com:apache/spark into string (Apr 15, 2015)
2772f0d  fix new test failure (Apr 15, 2015)
3b7bfa8  fix schema of AddJar (Apr 15, 2015)
10 changes: 5 additions & 5 deletions python/pyspark/sql/dataframe.py
@@ -456,7 +456,7 @@ def join(self, other, joinExprs=None, joinType=None):
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.

>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
- [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
+ [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
"""

if joinExprs is None:
@@ -637,9 +637,9 @@ def groupBy(self, *cols):
>>> df.groupBy().avg().collect()
[Row(AVG(age)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
- [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
+ [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(df.name).avg().collect()
- [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
+ [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
@@ -867,11 +867,11 @@ def agg(self, *exprs):

>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
- [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
+ [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]

>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
- [Row(MIN(age)=5), Row(MIN(age)=2)]
+ [Row(MIN(age)=2), Row(MIN(age)=5)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
3 changes: 2 additions & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3

import org.apache.spark.sql.catalyst.expressions.GenericRow
- import org.apache.spark.sql.types.{StructType, DateUtils}
+ import org.apache.spark.sql.types.StructType

object Row {
/**
@@ -257,6 +257,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
+ // TODO(davies): This is not the right default implementation; we use Int as Date internally
Contributor: Are we expecting a Date object or its Int representation? For internal use, I guess we expect an int. But since users can also call it, for them a Date object is expected, right? Right now, if I call getDate, I will get a ClassCastException?

Can you file a JIRA for it for 1.4 and mark it as a blocker?

def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
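
To make the concern above concrete, here is a minimal sketch (not PR code; `getDateSafely` is a hypothetical helper, and it assumes the internal Int days-since-epoch representation mentioned in the TODO):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.DateUtils

// Hypothetical helper: accept both the external java.sql.Date and the
// internal Int (days since epoch) representation.
def getDateSafely(row: Row, i: Int): java.sql.Date = row(i) match {
  case d: java.sql.Date => d                    // external row
  case days: Int => DateUtils.toJavaDate(days)  // internal representation
}
```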

/**
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -77,6 +77,9 @@ object CatalystTypeConverters {
}
new GenericRowWithSchema(ar, structType)

+ case (d: String, _) =>
+ UTF8String(d)
+
case (d: BigDecimal, _) =>
Decimal(d)

@@ -175,6 +178,11 @@ object CatalystTypeConverters {
case other => other
}

+ case dataType: StringType => (item: Any) => extractOption(item) match {
+ case s: String => UTF8String(s)
+ case other => other
+ }
+
case _ =>
(item: Any) => extractOption(item) match {
case d: BigDecimal => Decimal(d)
@@ -184,6 +192,26 @@
}
}

+ /**
+ * Converts Scala objects to Catalyst rows / types.
+ *
+ * Note: this should be called before doing any evaluation on a Row
+ * (it does not support UDTs).
+ * This is used to create an RDD or test results with correct types for Catalyst.
+ */
+ def convertToCatalyst(a: Any): Any = a match {
+ case s: String => UTF8String(s)
+ case d: java.sql.Date => DateUtils.fromJavaDate(d)
+ case d: BigDecimal => Decimal(d)
+ case d: java.math.BigDecimal => Decimal(d)
+ case seq: Seq[Any] => seq.map(convertToCatalyst)
+ case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
+ case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
+ case m: Map[Any, Any] =>
+ m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
+ case other => other
+ }
+
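A quick usage sketch of the new helper, derived directly from the cases above (the import path is an assumption based on this file's package):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToCatalyst

convertToCatalyst("abc")              // UTF8String("abc")
convertToCatalyst(BigDecimal("1.5"))  // Decimal(1.5)
convertToCatalyst(Seq("a", "b"))      // Seq(UTF8String("a"), UTF8String("b"))
convertToCatalyst(Row("x", 1))        // Row(UTF8String("x"), 1)
```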
/**
* Converts Catalyst types used internally in rows to standard Scala types
* This method is slow, and for batch conversion you should be using converter
@@ -211,6 +239,9 @@ object CatalystTypeConverters {
case (i: Int, DateType) =>
DateUtils.toJavaDate(i)

+ case (s: UTF8String, StringType) =>
+ s.toString()
+
case (other, _) =>
other
}
@@ -262,6 +293,12 @@
case other => other
}

+ case StringType =>
+ (item: Any) => item match {
+ case s: UTF8String => s.toString()
+ case other => other
+ }
+
case other =>
(item: Any) => item
}
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -138,6 +138,7 @@ trait ScalaReflection {
// The data type can be determined without ambiguity.
case obj: BooleanType.JvmType => BooleanType
case obj: BinaryType.JvmType => BinaryType
+ case obj: String => StringType
case obj: StringType.JvmType => StringType
case obj: ByteType.JvmType => ByteType
case obj: ShortType.JvmType => ShortType
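Why both cases are needed: after this PR `StringType.JvmType` is `UTF8String`, so a plain `String` no longer matches it. A hypothetical illustration (not PR code):

```scala
import org.apache.spark.sql.types._

// Both string representations must map to StringType once the JvmType changes.
def typeOfExample(obj: Any): DataType = obj match {
  case _: String => StringType      // external Scala/Java string
  case _: UTF8String => StringType  // internal Catalyst representation
  case _: Int => IntegerType
}
```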
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -115,7 +115,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
- val stringNaN = Literal.create("NaN", StringType)
+ val stringNaN = Literal("NaN")

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
@@ -563,6 +563,10 @@ trait HiveTypeCoercion {
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))

+ // Compatible with Hive
+ case Substring(e, start, len) if e.dataType != StringType =>
+ Substring(Cast(e, StringType), start, len)
+
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
- import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.types._

/** Cast the child expression to the target data type. */
@@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w

// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
- case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
- case DateType => buildCast[Int](_, d => DateUtils.toString(d))
- case TimestampType => buildCast[Timestamp](_, timestampToString)
- case _ => buildCast[Any](_, _.toString)
+ case BinaryType => buildCast[Array[Byte]](_, UTF8String(_))
+ case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d)))
+ case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t)))
+ case _ => buildCast[Any](_, o => UTF8String(o.toString))
}

// BinaryConverter
private[this] def castToBinary(from: DataType): Any => Any = from match {
- case StringType => buildCast[String](_, _.getBytes("UTF-8"))
+ case StringType => buildCast[UTF8String](_, _.getBytes)
}

// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, _.length() != 0)
+ buildCast[UTF8String](_, _.length() != 0)
case TimestampType =>
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
case DateType =>
@@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// TimestampConverter
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => {
+ buildCast[UTF8String](_, utfs => {
// Throw away extra if more than 9 decimal places
+ val s = utfs.toString
val periodIdx = s.indexOf(".")
var n = s
if (periodIdx != -1 && n.length() - periodIdx > 9) {
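The 9-digit cap lines up with java.sql.Timestamp's nanosecond precision; a small sketch of the assumption behind the truncation:

```scala
// Timestamp.valueOf parses at most 9 fractional digits (nanoseconds):
java.sql.Timestamp.valueOf("2015-04-01 12:00:00.123456789")  // parses fine
// A 10th digit ("...00.1234567891") would be rejected, which is why the
// cast trims anything beyond 9 decimal places before parsing.
```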
@@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DateConverter
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s =>
- try DateUtils.fromJavaDate(Date.valueOf(s))
+ buildCast[UTF8String](_, s =>
+ try DateUtils.fromJavaDate(Date.valueOf(s.toString))
catch { case _: java.lang.IllegalArgumentException => null }
)
case TimestampType =>
@@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toLong catch {
+ buildCast[UTF8String](_, s => try s.toString.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toInt catch {
+ buildCast[UTF8String](_, s => try s.toString.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toShort catch {
+ buildCast[UTF8String](_, s => try s.toString.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toByte catch {
+ buildCast[UTF8String](_, s => try s.toString.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
Expand Down Expand Up @@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w

private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
+ buildCast[UTF8String](_, s => try {
+ changePrecision(Decimal(s.toString.toDouble), target)

Contributor: Not quite related to your change, but why do we convert the string to double first?

Contributor Author: I guess Double gives the widest format and range for numbers; otherwise we would need a special parser for it.

+ } catch {
case _: NumberFormatException => null
})
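On the reviewer's question above: the detour through Double can silently lose precision for long decimal strings, which this sketch illustrates (plain Scala, not PR code):

```scala
// Double keeps only ~15-17 significant decimal digits:
BigDecimal("12345678901234567890.12345")           // exact
BigDecimal("12345678901234567890.12345".toDouble)  // rounded to Double precision
```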
case BooleanType =>
@@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toDouble catch {
+ buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toFloat catch {
+ buildCast[UTF8String](_, s => try s.toString.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
new GenericRow(newValues)
}

- override def update(ordinal: Int, value: Any): Unit = {
- if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
+ override def update(ordinal: Int, value: Any) {
+ if (value == null) {
+ setNullAt(ordinal)
+ } else {
+ values(ordinal).update(value)
+ }
}

- override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
+ override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))

Contributor: Why are we still expecting a String as the input parameter?

Contributor Author: This is an API, so we should keep it.

- override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int): String = apply(ordinal).toString

override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
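To spell out the exchange above: the public accessors keep their `String` signatures while storage switches to `UTF8String`. A sketch with a hypothetical stand-in class (`MiniRow` is not the PR's class; it only assumes the `UTF8String(...)` factory and `toString` shown in this diff):

```scala
import org.apache.spark.sql.types.UTF8String

// String at the API boundary, UTF8String in storage.
class MiniRow(size: Int) {
  private val values = new Array[Any](size)
  def setString(i: Int, v: String): Unit = values(i) = UTF8String(v)  // convert on the way in
  def getString(i: Int): String = values(i).toString                  // convert on the way out
}

val r = new MiniRow(1)
r.setString(0, "abc")
assert(r.getString(0) == "abc")
```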
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val $primitiveTerm: ${termForType(dataType)} = $value
""".children

- case expressions.Literal(value: String, dataType) =>
+ case expressions.Literal(value: UTF8String, dataType) =>
q"""
val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
+ val $primitiveTerm: ${termForType(dataType)} =
+ org.apache.spark.sql.types.UTF8String(${value.getBytes})
""".children

case expressions.Literal(value: Int, dataType) =>
@@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
""".children

case Cast(child @ DateType(), StringType) =>
child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType)
child.castOrNull(c =>
q"""org.apache.spark.sql.types.UTF8String(
org.apache.spark.sql.types.DateUtils.toString($c))""",
StringType)

case Cast(child @ NumericType(), IntegerType) =>
child.castOrNull(c => q"$c.toInt", IntegerType)
@@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- ${eval.primitiveTerm}.toString
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
""".children

+ case EqualTo(e1: BinaryType, e2: BinaryType) =>
+ (e1, e2).evaluateAs (BooleanType) {
+ case (eval1, eval2) =>
+ q"""
+ java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
+ $eval2.asInstanceOf[Array[Byte]])
+ """
+ }
+
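Why the new special case is needed: on the JVM, `==` on arrays compares references, not contents, so generated comparisons of `Array[Byte]` must go through `java.util.Arrays.equals` (a quick illustration, not PR code):

```scala
val a = Array[Byte](1, 2, 3)
val b = Array[Byte](1, 2, 3)
a == b                         // false: different array instances
java.util.Arrays.equals(a, b)  // true: element-wise comparison
```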
case EqualTo(e1, e2) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }

@@ -597,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val localLogger = log
val localLoggerTree = reify { localLogger }
q"""
- $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm))
+ $localLoggerTree.debug(
+ ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
""" :: Nil
} else {
Nil
@@ -608,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin

protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
dataType match {
case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
}
@@ -619,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
ordinal: Int,
value: TermName) = {
dataType match {
case StringType => q"$destinationRow.update($ordinal, $value)"
Contributor: Seems this one is needed because getString returns a String.

Contributor Author: Yes.
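
In other words, the generated setter cannot use the typed mutator for StringType, since `setString` takes a `java.lang.String` while the generated value is a `UTF8String`. A sketch of the dispatch decision (hypothetical helper, not PR code):

```scala
import org.apache.spark.sql.types._

// StringType falls back to the generic update; native types use typed mutators.
def mutatorFor(dt: DataType): String = dt match {
  case StringType => "update"   // setString expects String, but codegen produces UTF8String
  case IntegerType => "setInt"
  case DoubleType => "setDouble"
  case _ => "update"
}
```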

case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => q"$destinationRow.update($ordinal, $value)"
}
@@ -642,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
case StringType => "String"
case StringType => "org.apache.spark.sql.types.UTF8String"
}

protected def defaultPrimitive(dt: DataType) = dt match {
case BooleanType => ru.Literal(Constant(false))
case FloatType => ru.Literal(Constant(-1.0.toFloat))
case StringType => ru.Literal(Constant("<uninit>"))
case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
case ShortType => ru.Literal(Constant(-1.toShort))
case LongType => ru.Literal(Constant(-1L))
case ByteType => ru.Literal(Constant(-1.toByte))