@@ -392,7 +392,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     SparkPlan.currentContext.set(self)
     val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
     val attributeSeq = schema.toAttributes
-    val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
+    val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
     DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self))
   }
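
For context on the new second argument: StructType is itself a Seq[StructField], so the call site can derive the Seq[DataType] directly from the schema. A minimal sketch with a made-up two-column schema:

    import org.apache.spark.sql.types._

    // Hypothetical schema, for illustration only
    val schema = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", StringType)))

    // Mapping over a StructType yields the column types in order:
    val outputTypes: Seq[DataType] = schema.map(_.dataType)
    // => Seq(IntegerType, StringType)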

@@ -21,63 +21,49 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.{Row, SQLContext}

 /**
  * :: DeveloperApi ::
  */
 @DeveloperApi
 object RDDConversions {
-  def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = {
+  def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[Row] = {
     data.mapPartitions { iterator =>
       if (iterator.isEmpty) {
         Iterator.empty
       } else {
-        val bufferedIterator = iterator.buffered
-        val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType))
-        val schemaFields = schema.fields.toArray
-        val converters = schemaFields.map {
-          f => CatalystTypeConverters.createToCatalystConverter(f.dataType)
-        }
-        bufferedIterator.map { r =>
-          var i = 0
-          while (i < mutableRow.length) {
-            mutableRow(i) = converters(i)(r.productElement(i))
-            i += 1
-          }
-
-          mutableRow
+        val numColumns = outputTypes.length
+        val mutableRow = new GenericMutableRow(numColumns)
+        val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter)
+        iterator.map { r =>
+          var i = 0
+          while (i < numColumns) {
+            mutableRow(i) = converters(i)(r.productElement(i))
+            i += 1
+          }
+
+          mutableRow
         }
       }
     }

Review thread on the new `iterator.map { r =>` line:

Contributor: Maybe a while loop will be better?

Author: This is actually the return value of the mapPartitions call, which must be a Scala iterator. Note that we do use a while loop to iterate over the columns.

Contributor: Actually, I am fine with the current version. We are not slower than the previous version.

Author: This is probably faster than the older version, since it doesn't use an unnecessary bufferedIterator :)

Contributor: Yeah, I also realized that after I made my comment.

   /**
    * Convert the objects inside Row into the types Catalyst expected.
    */
-  def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = {
+  def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[Row] = {
     data.mapPartitions { iterator =>
       if (iterator.isEmpty) {
         Iterator.empty
       } else {
-        val bufferedIterator = iterator.buffered
-        val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray)
-        val schemaFields = schema.fields.toArray
-        val converters = schemaFields.map {
-          f => CatalystTypeConverters.createToCatalystConverter(f.dataType)
-        }
-        bufferedIterator.map { r =>
-          var i = 0
-          while (i < mutableRow.length) {
-            mutableRow(i) = converters(i)(r(i))
-            i += 1
-          }
-
-          mutableRow
+        val numColumns = outputTypes.length
+        val mutableRow = new GenericMutableRow(numColumns)
+        val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter)
+        iterator.map { r =>
+          var i = 0
+          while (i < numColumns) {
+            mutableRow(i) = converters(i)(r(i))
+            i += 1
+          }
+
+          mutableRow
         }
       }
     }
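
To make the pattern shared by both methods concrete outside Spark's internals, here is a self-contained sketch (hypothetical names, not Spark API): one mutable buffer is allocated per partition, filled in place with a while loop, and returned for every element of the output iterator. Reuse is safe because mapPartitions consumers are expected to read each row before requesting the next one.

    // Sketch only -- convertPartition, getField, and the Array-based "row" are invented here.
    def convertPartition[A](
        input: Iterator[A],
        converters: Array[Any => Any],
        getField: (A, Int) => Any): Iterator[Array[Any]] = {
      val numColumns = converters.length
      val mutableRow = new Array[Any](numColumns)  // allocated once per partition
      input.map { r =>
        var i = 0
        while (i < numColumns) {  // while loop keeps per-column work allocation-free
          mutableRow(i) = converters(i)(getField(r, i))
          i += 1
        }
        mutableRow  // the same buffer every time; consume it before calling next()
      }
    }
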
@@ -309,7 +309,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       output: Seq[Attribute],
       rdd: RDD[Row]): SparkPlan = {
     val converted = if (relation.needConversion) {
-      execution.RDDConversions.rowToRowRdd(rdd, relation.schema)
+      execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
     } else {
       rdd
     }
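
The switch from relation.schema to output.map(_.dataType) matters because the scan's output attributes can disagree positionally with the relation's full schema, for example when a projection repeats or prunes columns (see the SPARK-7858 reproduction test further down). A hypothetical illustration of the mismatch, with invented column names:

    import org.apache.spark.sql.types._

    // Suppose the relation's full schema is (a: Int, b: String)...
    val relationSchema = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", StringType)))

    // ...but the physical scan for SELECT b, b produces two string columns:
    val outputTypes: Seq[DataType] = Seq(StringType, StringType)

    // Converting positionally against relationSchema would pair the Int converter
    // with a String value; output.map(_.dataType) keeps converters aligned with the data.
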
@@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}

 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.{Row, SQLContext}

@@ -108,7 +109,10 @@ class SimpleTextRelation(

     sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record =>
       Row(record.split(",").zip(fields).map { case (value, dataType) =>
-        Cast(Literal(value), dataType).eval()
+        // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.)
+        val catalystValue = Cast(Literal(value), dataType).eval()
+        // Here we're converting Catalyst values to Scala values to test `needsConversion`
+        CatalystTypeConverters.convertToScala(catalystValue, dataType)
       }: _*)
     }
   }
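
The new comments carry the key subtlety: evaluating a Cast yields Catalyst-internal values (UTF8String instead of String, etc.), while external Rows are expected to hold plain Scala/Java values. A rough sketch of the round trip, mirroring the calls in the diff; the exact signatures of this era's Catalyst internals are an assumption:

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.StringType

    // eval() on the Cast produces the Catalyst representation (a UTF8String here)...
    val catalystValue = Cast(Literal("val_1"), StringType).eval()
    // ...which convertToScala turns back into a java.lang.String for the external Row:
    val scalaValue = CatalystTypeConverters.convertToScala(catalystValue, StringType)
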
@@ -76,6 +76,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
       df.filter('a > 1 && 'p1 < 2).select('b, 'p1),
       for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1))

+    // Project many copies of columns with different types (reproduction for SPARK-7858)
+    checkAnswer(
+      df.filter('a > 1 && 'p1 < 2).select('b, 'b, 'b, 'b, 'p1, 'p1, 'p1, 'p1),
+      for (i <- 2 to 3; _ <- Seq("foo", "bar"))
+        yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1))
+
     // Self-join
     df.registerTempTable("t")
     withTempTable("t") {

Review thread on the first checkAnswer call above:

Contributor: @JoshRosen Found that I made a mistake in this checkAnswer test, which should have caught this bug... Values of p1 should be either "foo" or "bar", but I put a 1 there. The correct version of this test should be:

    // Simple projection and partition pruning
    checkAnswer(
      df.filter('a > 1 && 'p1 < 2).select('b, 'p1),
      for (i <- 2 to 3; j <- Seq("foo", "bar")) yield Row(s"val_$i", j))

Author: Are you sure that p1 is string-typed? I think that p2 is, but there's other code which implies that p1 should be an int, e.g. at the top of HadoopFsRelationTest:

    val partitionedTestDF1 = (for {
      i <- 1 to 3
      p2 <- Seq("foo", "bar")
    } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2")

    val partitionedTestDF2 = (for {
      i <- 1 to 3
      p2 <- Seq("foo", "bar")
    } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2")

Contributor: Yeah, I think the type of p1 is int. cc @liancheng

Contributor: Ah, sorry, my bad.