@@ -61,6 +61,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.ByteTpe => ByteType
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
Contributor: Normally Decimal should only be used inside Spark SQL as the internal representation of the decimal type, so we shouldn't need to catch it here. Do we break it in tests?

Member Author: Yes. constructorFor calls dataTypeFor to determine whether a type is an ObjectType or not. If there is no case for Decimal, it is recognized as ObjectType, which causes the bug.
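As a rough illustration of that failure mode (a hypothetical, self-contained sketch with stand-in types, not the actual ScalaReflection code), the catch-all case is what turns an unmatched Decimal into an ObjectType:

```scala
// Hypothetical, simplified stand-ins; the real Spark types and match cases differ.
object DataTypeForSketch {
  sealed trait CatalystType
  case object DecimalSystemDefault extends CatalystType
  case class ObjectType(cls: Class[_]) extends CatalystType

  final class Decimal // stand-in for org.apache.spark.sql.types.Decimal

  def dataTypeFor(cls: Class[_]): CatalystType = cls match {
    case c if c == classOf[Decimal] => DecimalSystemDefault // the case added by this diff
    case other                      => ObjectType(other)    // previous fall-through
  }

  def main(args: Array[String]): Unit = {
    // With the Decimal case in place, constructorFor no longer sees an ObjectType here.
    println(dataTypeFor(classOf[Decimal])) // DecimalSystemDefault
    println(dataTypeFor(classOf[String]))  // ObjectType(class java.lang.String)
  }
}
```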

case _ =>
val className = getClassNameFromType(tpe)
className match {
@@ -177,6 +178,7 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}

val className = getClassNameFromType(tpe)
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath

@@ -372,6 +374,17 @@ object ScalaReflection extends ScalaReflection {
} else {
newInstance
}

case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
val udt = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
}
}

@@ -406,11 +419,16 @@ object ScalaReflection extends ScalaReflection {
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
val externalDataType = dataTypeFor(elementType)
val Schema(catalystType, nullable) = silentSchemaFor(elementType)
if (isNativeType(catalystType)) {
NewInstance(

if (isNativeType(catalystType) && !(elementType <:< localTypeOf[Option[_]])) {
Contributor: Good catch! The original check is wrong; how about isNativeType(externalDataType), which is simpler?

Member Author: Yes, thanks. I propose that fix as another PR, #10391.
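For context, a self-contained sketch of why the original gate is too coarse and why checking the external type folds in the Option case (hypothetical stand-in types, not Spark's real ScalaReflection):

```scala
// Hypothetical stand-ins: for Seq[Option[Int]] the Catalyst element type is a native
// IntegerType, but the external element type is an object, so gating the fast path on
// the Catalyst type alone would wrongly skip per-element conversion.
object ToCatalystArraySketch {
  sealed trait DataType
  case object IntegerType extends DataType
  case class ObjectType(cls: Class[_]) extends DataType

  def isNativeType(dt: DataType): Boolean = dt == IntegerType

  def main(args: Array[String]): Unit = {
    val catalystElementType: DataType = IntegerType                    // schema side of Option[Int]
    val externalElementType: DataType = ObjectType(classOf[Option[_]]) // runtime representation

    println(isNativeType(catalystElementType)) // true  -> original check takes the fast path
    println(isNativeType(externalElementType)) // false -> suggested check falls back safely
  }
}
```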

val array = NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(catalystType, nullable))
expressions.If(
IsNull(input),
expressions.Literal.create(null, ArrayType(catalystType, nullable)),
array)
} else {
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
@@ -421,6 +439,7 @@
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
val className = getClassNameFromType(tpe)
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Expand Down Expand Up @@ -589,6 +608,17 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)

case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
val udt = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)

case other =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -462,7 +462,11 @@ case class MapObjects(
$convertedArray[$loopIndex] = null;
} else {
${genFunction.code}
$convertedArray[$loopIndex] = ${genFunction.value};
if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null;
} else {
$convertedArray[$loopIndex] = ${genFunction.value};
}
}

$loopIndex += 1;
@@ -937,8 +937,9 @@ object DecimalAggregates extends Rule[LogicalPlan] {
object ConvertToLocalRelation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Project(projectList, LocalRelation(output, data)) =>
val projection = new InterpretedProjection(projectList, output)
LocalRelation(projectList.map(_.toAttribute), data.map(projection))
val projection = UnsafeProjection.create(projectList, output)
LocalRelation(projectList.map(_.toAttribute),
data.map(projection(_).copy()))
}
}

@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
import org.apache.spark.sql.types.{StructField, StructType}

@@ -29,20 +31,28 @@ object LocalRelation {
new LocalRelation(StructType(output1 +: output).toAttributes)
}

def fromInternalRows(output: Seq[Attribute], data: Seq[InternalRow]): LocalRelation = {
val projection = UnsafeProjection.create(output.map(_.dataType).toArray)
new LocalRelation(output, data.map(projection(_).copy()))
}

def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = {
val schema = StructType.fromAttributes(output)
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
val internalRows = data.map(converter(_).asInstanceOf[InternalRow])
fromInternalRows(output, internalRows)
}

def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
def fromProduct[T <: Product : ExpressionEncoder](
output: Seq[Attribute],
data: Seq[T]): LocalRelation = {
val encoder = encoderFor[T]
val schema = StructType.fromAttributes(output)
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
new LocalRelation(output, data.map(encoder.toRow(_).copy().asInstanceOf[UnsafeRow]))
}
}

case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
case class LocalRelation(output: Seq[Attribute], data: Seq[UnsafeRow] = Nil)
extends LeafNode with analysis.MultiInstanceRelation {

/**
@@ -18,14 +18,30 @@
package org.apache.spark.sql.catalyst.util

import scala.collection.JavaConverters._
import scala.collection.mutable.WrappedArray

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, Decimal}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

object GenericArrayData {
def processSeq(seq: Seq[Any]): Array[Any] = {
seq match {
case wArray: WrappedArray[_] =>
if (wArray.array == null) {
null
} else {
wArray.toArray[Any]
}
case null => null
case _ => seq.toArray
}
}
}

class GenericArrayData(val array: Array[Any]) extends ArrayData {

def this(seq: Seq[Any]) = this(seq.toArray)
def this(seq: Seq[Any]) = this(GenericArrayData.processSeq(seq))
Contributor: Instead of handling null here, I think a better way is to not pass null to it in the first place.

Member Author: Thanks. I propose that fix as PR #10401.
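A minimal sketch of the caller-side alternative being discussed (hypothetical names only; the actual change lives in PR #10401): keep GenericArrayData free of null checks and have the caller translate a null sequence into a SQL NULL before constructing it.

```scala
// Hypothetical sketch; GenericArrayDataSketch stands in for GenericArrayData.
object NullGuardSketch {
  final class GenericArrayDataSketch(val array: Array[Any]) {
    require(array != null, "callers must pass a SQL NULL instead of a null array")
    def numElements: Int = array.length
  }

  // Caller-side guard: a null input never reaches the constructor.
  def toCatalystArray(input: Seq[Any]): Option[GenericArrayDataSketch] =
    Option(input).map(s => new GenericArrayDataSketch(s.toArray))

  def main(args: Array[String]): Unit = {
    println(toCatalystArray(Seq(1, 2, 3)).map(_.numElements)) // Some(3)
    println(toCatalystArray(null))                            // None -> represent as SQL NULL upstream
  }
}
```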

def this(list: java.util.List[Any]) = this(list.asScala)

// TODO: This is boxing. We should specialize.
@@ -39,7 +55,11 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {

override def copy(): ArrayData = new GenericArrayData(array.clone())

override def numElements(): Int = array.length
override def numElements(): Int = if (array != null) {
array.length
} else {
0
}

private def getAs[T](ordinal: Int) = array(ordinal).asInstanceOf[T]
override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
@@ -35,11 +35,11 @@ class ConvertToLocalRelationSuite extends PlanTest {
}

test("Project on LocalRelation should be turned into a single LocalRelation") {
val testRelation = LocalRelation(
val testRelation = LocalRelation.fromInternalRows(
LocalRelation('a.int, 'b.int).output,
InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

val correctAnswer = LocalRelation(
val correctAnswer = LocalRelation.fromInternalRows(
LocalRelation('a1.int, 'b1.int).output,
InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

7 changes: 5 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,6 +33,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -426,6 +427,7 @@ class SQLContext private[sql](
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
SQLContext.setActive(self)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
@@ -501,7 +503,7 @@ class SQLContext private[sql](
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
val enc = encoderFor[T]
val attributes = enc.schema.toAttributes
val encoded = data.map(d => enc.toRow(d).copy())
val encoded = data.map(d => enc.toRow(d).copy().asInstanceOf[UnsafeRow])
val plan = new LocalRelation(attributes, encoded)

new Dataset[T](this, plan)
@@ -604,7 +606,8 @@ class SQLContext private[sql](
val className = beanClass.getName
val beanInfo = Introspector.getBeanInfo(beanClass)
val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
DataFrame(self, LocalRelation(attrSeq, rows.toSeq))
DataFrame(self,
LocalRelation.fromInternalRows(attrSeq, rows.toSeq))
}


@@ -19,17 +19,26 @@ package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, UnsafeProjection, UnsafeRow}

private[sql] object LocalTableScan {
def fromInternalRows(output: Seq[Attribute], data: Seq[InternalRow]): LocalTableScan = {
val projection = UnsafeProjection.create(output.map(_.dataType).toArray)
new LocalTableScan(output, data.map(projection(_).copy()))
}
}

/**
* Physical plan node for scanning data from a local collection.
*/
private[sql] case class LocalTableScan(
output: Seq[Attribute],
rows: Seq[InternalRow]) extends LeafNode {
rows: Seq[UnsafeRow]) extends LeafNode {

override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true

private lazy val rdd = sqlContext.sparkContext.parallelize(rows)
private lazy val rdd = sqlContext.sparkContext.parallelize(rows).asInstanceOf[RDD[InternalRow]]

protected override def doExecute(): RDD[InternalRow] = rdd

@@ -162,6 +162,10 @@ case class Limit(limit: Int, child: SparkPlan)
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition

override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = true

override def executeCollect(): Array[InternalRow] = child.executeTake(limit)

protected override def doExecute(): RDD[InternalRow] = {
@@ -29,14 +29,18 @@ case class ExpandNode(

assert(projections.size > 0)

override def canProcessUnsafeRows: Boolean = true

override def outputsUnsafeRows: Boolean = true

private[this] var result: InternalRow = _
private[this] var idx: Int = _
private[this] var input: InternalRow = _
private[this] var groups: Array[Projection] = _

override def open(): Unit = {
child.open()
groups = projections.map(ee => newMutableProjection(ee, child.output)()).toArray
groups = projections.map(ee => UnsafeProjection.create(ee, child.output)).toArray
idx = groups.length
}

@@ -114,7 +114,7 @@ private[sql] object FrequentItems extends Logging {
baseCounts
}
)
val justItems = freqItems.map(m => m.baseMap.keys.toArray)
val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
val resultRow = Row(justItems : _*)
// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
@@ -144,6 +144,7 @@ private[sql] object StatFunctions extends Logging {
}
val schema = StructType(StructField(tableName, StringType) +: headerNames)

new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
new DataFrame(df.sqlContext,
LocalRelation.fromInternalRows(schema.toAttributes, table)).na.fill(0.0)
}
}
@@ -38,7 +38,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name"))
assert(outputsUnsafe.outputsUnsafeRows)

test("planner should insert unsafe->safe conversions when required") {
ignore("planner should insert unsafe->safe conversions when required") {
val plan = Limit(10, outputsUnsafe)
val preparedPlan = sqlContext.prepareForExecution.execute(plan)
assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
@@ -134,12 +134,11 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
val rows = (1 to 100).map { i =>
InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))
}
val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows)
val relation = LocalTableScan.fromInternalRows(Seq(AttributeReference("t", schema)()), rows)

val plan =
DummyPlan(
ConvertToSafe(
ConvertToUnsafe(relation)))
ConvertToSafe(relation))
assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString))
}
}
@@ -17,11 +17,26 @@

package org.apache.spark.sql.execution.local

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

private[local] object DummyNode {
val CLOSED: Int = Int.MinValue

def apply[A <: Product : TypeTag](
output: Seq[Attribute],
data: Seq[A],
conf: SQLConf = new SQLConf): DummyNode = {
implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
new DummyNode(output, LocalRelation.fromProduct(output, data), conf)
}
}

/**
* A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
*/
@@ -36,10 +51,6 @@ private[local] case class DummyNode(
private var index: Int = CLOSED
private val input: Seq[InternalRow] = relation.data

def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) {
this(output, LocalRelation.fromProduct(output, data), conf)
}

def isOpen: Boolean = index != CLOSED

override def children: Seq[LocalNode] = Seq.empty
@@ -62,7 +73,3 @@ private[local] case class DummyNode(
index = CLOSED
}
}

private object DummyNode {
val CLOSED: Int = Int.MinValue
}