Changes from all commits

Commits (41)
d2ea80d  Integrate ColumnNode AST into Column.scala (hvanhovell, Jul 24, 2024)
e7a2a32  Add internally registered functions (hvanhovell, Aug 7, 2024)
a4e52f4  Move window to cool new API :) (hvanhovell, Aug 7, 2024)
dcde4d4  Improve Window (hvanhovell, Aug 9, 2024)
9acfdbe  Refactor ColumnNode API (hvanhovell, Aug 9, 2024)
6e31176  Support UDFs/UDAFs (hvanhovell, Aug 9, 2024)
ea07c58  Regular Fixes (hvanhovell, Aug 13, 2024)
73b1812  UDF Fixes (hvanhovell, Aug 13, 2024)
3e41a98  Add test for ColumnNode sql and normalize (hvanhovell, Aug 13, 2024)
8dcb381  Merge remote-tracking branch 'apache/master' into SPARK-49022 (hvanhovell, Aug 14, 2024)
87a7b1e  wip (hvanhovell, Aug 14, 2024)
5fe4b18  Merge remote-tracking branch 'apache/master' into SPARK-49022 (hvanhovell, Aug 14, 2024)
763a082  Fix pyspark issues (hvanhovell, Aug 14, 2024)
4244ef6  fixes (hvanhovell, Aug 15, 2024)
e573f7c  Fix Connect MiMa (hvanhovell, Aug 15, 2024)
4ba1d94  Fix docs (hvanhovell, Aug 15, 2024)
35467a9  Merge remote-tracking branch 'apache/master' into SPARK-49022 (hvanhovell, Aug 15, 2024)
7318f60  style (hvanhovell, Aug 15, 2024)
c73ef8e  merge artifact (hvanhovell, Aug 15, 2024)
40afb9a  Merge branch 'SPARK-49022' into SPARK-49025 (hvanhovell, Aug 15, 2024)
6f84348  Remove expr() in scala/java land (hvanhovell, Aug 15, 2024)
33a0c38  Remove Column.apply(Expression) (hvanhovell, Aug 16, 2024)
32d4138  Fix pyspark (hvanhovell, Aug 16, 2024)
e365d5a  style (hvanhovell, Aug 16, 2024)
7456fbf  Merge branch 'SPARK-49022' into SPARK-49025 (hvanhovell, Aug 16, 2024)
b818d89  python typing (hvanhovell, Aug 16, 2024)
a15f49b  Merge branch 'SPARK-49022' into SPARK-49025 (hvanhovell, Aug 16, 2024)
73796a7  fix fix fix (hvanhovell, Aug 16, 2024)
b4f9608  Code Review (hvanhovell, Aug 16, 2024)
2e92e06  Merge branch 'SPARK-49022' into SPARK-49025 (hvanhovell, Aug 17, 2024)
5104c5d  Merge remote-tracking branch 'apache/master' into SPARK-49025 (hvanhovell, Aug 17, 2024)
232b42a  Learn to properly merge conflicts... (hvanhovell, Aug 17, 2024)
df29113  fix import (hvanhovell, Aug 17, 2024)
7077ebd  Fix tests (hvanhovell, Aug 17, 2024)
e7440fc  Fix Scala Tests (hvanhovell, Aug 18, 2024)
ec51bd8  Fix Pandas UDFs (hvanhovell, Aug 18, 2024)
058a806  Not needed anymore (hvanhovell, Aug 18, 2024)
6d404e1  style (hvanhovell, Aug 18, 2024)
90855e2  oops (hvanhovell, Aug 18, 2024)
92a2f45  Merge remote-tracking branch 'apache/master' into SPARK-49025 (hvanhovell, Aug 19, 2024)
69cb29d  Fix UDTF (hvanhovell, Aug 19, 2024)
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}


// scalastyle:off: object.name
@@ -41,7 +42,7 @@ object functions {
def from_avro(
data: Column,
jsonFormatSchema: String): Column = {
Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty))
AvroDataToCatalyst(data, jsonFormatSchema, Map.empty)
}

/**
@@ -62,7 +63,7 @@ object functions {
data: Column,
jsonFormatSchema: String,
options: java.util.Map[String, String]): Column = {
Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap))
AvroDataToCatalyst(data, jsonFormatSchema, options.asScala.toMap)
}

/**
@@ -74,7 +75,7 @@
*/
@Experimental
def to_avro(data: Column): Column = {
Column(CatalystDataToAvro(data.expr, None))
CatalystDataToAvro(data, None)
}

/**
@@ -87,6 +88,6 @@
*/
@Experimental
def to_avro(data: Column, jsonFormatSchema: String): Column = {
Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema)))
CatalystDataToAvro(data, Some(jsonFormatSchema))
}
}
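The rewritten bodies above return Catalyst expressions where a Column is expected, which compiles because the file now imports the ExpressionUtils.{column, expression} conversions. Below is a minimal, self-contained analogue of that pattern, not Spark's actual ExpressionUtils (whose real signatures may differ); it only illustrates how an expression-to-Column conversion removes the explicit Column(...) wrapping:

import scala.language.implicitConversions

// Stand-ins for Catalyst's Expression and the public Column wrapper.
final case class Expr(sql: String)
final case class Col(expr: Expr)

object ConversionSketch {
  // With this implicit in scope, a method declared to return Col can return the
  // raw Expr and let the compiler insert the wrapping, which is what the imported
  // `column` conversion does for real Catalyst expressions in the diff above.
  implicit def exprToCol(e: Expr): Col = Col(e)

  def fromAvroLike(data: Col, jsonFormatSchema: String): Col =
    Expr(s"from_avro(${data.expr.sql}, '$jsonFormatSchema')") // wrapped implicitly
}

object ConversionSketchDemo extends App {
  import ConversionSketch._
  println(fromAvroLike(Col(Expr("payload")), "{}").expr.sql) // from_avro(payload, '{}')
}

The companion `expression` import presumably works the other way, unwrapping a Column argument such as `data` into the Expression that constructors like AvroDataToCatalyst expect.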
@@ -306,6 +306,12 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.TypedColumn.expr"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.TypedColumn$"),

// ColumnNode conversions
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.Converter"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$Converter$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$RichColumn"),

// Datasource V2 partition transforms
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"),
@@ -433,6 +439,9 @@ object CheckConnectJvmClientCompatibility {
// SQLImplicits
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.session"),

// Column API
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Column.expr"),

// Steaming API
ProblemFilters.exclude[MissingTypesProblem](
"org.apache.spark.sql.streaming.DataStreamWriter" // Client version extends Logging
@@ -24,9 +24,9 @@ import java.time.LocalDateTime
import java.util.Properties

import org.apache.spark.SparkException
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.tags.DockerTest
@@ -303,7 +303,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
ArrayType(DecimalType(2, 2), true))
// Test write null values.
df.select(df.queryExecution.analyzed.output.map { a =>
Column(Literal.create(null, a.dataType)).as(a.name)
lit(null).cast(a.dataType).as(a.name)
}: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
}
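The replacement builds typed NULL columns through the public API instead of wrapping a Catalyst Literal by hand. A quick sketch of the equivalence, assuming a SparkSession named `spark` and made-up column names:

import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.{DecimalType, StringType}

// Both columns are NULLs of the requested type; no Column(Literal.create(...)) needed.
val nullDecimal = lit(null).cast(DecimalType(2, 2)).as("c1")
val nullString  = lit(null).cast(StringType).as("c2")
spark.range(1).select(nullDecimal, nullString).printSchema()
// c1: decimal(2,2) (nullable = true), c2: string (nullable = true)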

@@ -20,6 +20,7 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.protobuf.utils.ProtobufUtils

// scalastyle:off: object.name
@@ -66,15 +67,11 @@
*/
@Experimental
def from_protobuf(
data: Column,
messageName: String,
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]): Column = {
Column(
ProtobufDataToCatalyst(
data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap
)
)
data: Column,
messageName: String,
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]): Column = {
ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap)
}

/**
@@ -93,7 +90,7 @@
@Experimental
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent)))
ProtobufDataToCatalyst(data, messageName, Some(fileContent))
}

/**
@@ -112,7 +109,7 @@
@Experimental
def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet)))
ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet))
}

/**
@@ -132,7 +129,7 @@
*/
@Experimental
def from_protobuf(data: Column, messageClassName: String): Column = {
Column(ProtobufDataToCatalyst(data.expr, messageClassName))
ProtobufDataToCatalyst(data, messageClassName)
}

/**
@@ -156,7 +153,7 @@
data: Column,
messageClassName: String,
options: java.util.Map[String, String]): Column = {
Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap))
ProtobufDataToCatalyst(data, messageClassName, None, options.asScala.toMap)
}

/**
@@ -194,7 +191,7 @@
@Experimental
def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet)))
CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet))
}
/**
* Converts a column into binary of protobuf format. The Protobuf definition is provided
@@ -216,9 +213,7 @@
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
Column(
CatalystDataToProtobuf(data.expr, messageName, Some(fileContent), options.asScala.toMap)
)
CatalystDataToProtobuf(data, messageName, Some(fileContent), options.asScala.toMap)
}

/**
@@ -242,11 +237,7 @@
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]
): Column = {
Column(
CatalystDataToProtobuf(
data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap
)
)
CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap)
}

/**
@@ -266,7 +257,7 @@
*/
@Experimental
def to_protobuf(data: Column, messageClassName: String): Column = {
Column(CatalystDataToProtobuf(data.expr, messageClassName))
CatalystDataToProtobuf(data, messageClassName)
}

/**
@@ -288,6 +279,6 @@
@Experimental
def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String])
: Column = {
Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap))
CatalystDataToProtobuf(data, messageClassName, None, options.asScala.toMap)
}
}
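The public entry points keep the signatures shown above; only their internals changed. A hedged usage sketch of the messageClassName variants, where the DataFrame `df`, its binary column `payload`, and the compiled Protobuf class `com.example.Person` are all assumptions for illustration:

import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}

// Decode a binary protobuf column into a struct, then re-encode it.
val decoded   = df.select(from_protobuf(df("payload"), "com.example.Person").as("person"))
val reencoded = decoded.select(to_protobuf(decoded("person"), "com.example.Person").as("payload"))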
@@ -118,7 +118,7 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}

val mappedOutputCols = inputColNames.zip(tds).map { case (colName, td) =>
dataset.col(colName).expr.dataType match {
SchemaUtils.getSchemaField(dataset.schema, colName).dataType match {
case DoubleType =>
when(!col(colName).isNaN && col(colName) > td, lit(1.0))
.otherwise(lit(0.0))
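The old code read a column's type through Column.expr; the new code reads it from the schema via MLlib's SchemaUtils (which, as an assumption here, also resolves nested field names). Plain user code can get the same information through the public schema API, as in this small sketch with an assumed DataFrame `dataset` and column name:

import org.apache.spark.sql.types.DoubleType

// Inspect a column's type without touching Column.expr.
val isDouble = dataset.schema("feature").dataType == DoubleType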
@@ -189,13 +189,12 @@ class StringIndexer @Since("1.4.0") (
private def getSelectedCols(dataset: Dataset[_], inputCols: Seq[String]): Seq[Column] = {
inputCols.map { colName =>
val col = dataset.col(colName)
if (col.expr.dataType == StringType) {
col
} else {
// We don't count for NaN values. Because `StringIndexerAggregator` only processes strings,
// we replace NaNs with null in advance.
when(!isnan(col), col).cast(StringType)
}
// We don't count for NaN values. Because `StringIndexerAggregator` only processes strings,
// we replace NaNs with null in advance.
val fpTypes = Seq(DoubleType, FloatType).map(_.catalogString)
when(typeof(col).isin(fpTypes: _*) && isnan(col), lit(null))
.otherwise(col)
.cast(StringType)
}
}

Inline review comment from hvanhovell (Contributor, Author) on the when(...) expression above:

The optimizer will simplify this into:

  • when(isNaN(col), null).otherwise(col).cast(StringType) for Float/Double
  • col for String
  • col.cast(StringType) for other datatypes.
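The author's comment explains how the optimizer specializes the new expression per input type. The following standalone sketch exercises the same NaN-to-null construction; the SparkSession `spark` and the column name are assumptions:

import org.apache.spark.sql.functions.{col, isnan, lit, typeof, when}
import org.apache.spark.sql.types.{DoubleType, FloatType, StringType}

val fpTypes = Seq(DoubleType, FloatType).map(_.catalogString)
val df = spark.range(1).select(lit(Double.NaN).as("x"))

// NaN in a floating-point column becomes SQL NULL before the cast to string,
// so downstream string processing never sees the literal "NaN".
val cleaned = df.select(
  when(typeof(col("x")).isin(fpTypes: _*) && isnan(col("x")), lit(null))
    .otherwise(col("x"))
    .cast(StringType)
    .as("x_str"))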

@@ -87,17 +87,17 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
// Schema transformation.
val schema = dataset.schema

val vectorCols = $(inputCols).filter { c =>
dataset.col(c).expr.dataType match {
case _: VectorUDT => true
case _ => false
}
val inputColsWithField = $(inputCols).map { c =>
c -> SchemaUtils.getSchemaField(schema, c)
}

val vectorCols = inputColsWithField.collect {
case (c, field) if field.dataType.isInstanceOf[VectorUDT] => c
}
val vectorColsLengths = VectorAssembler.getLengths(
dataset, vectorCols.toImmutableArraySeq, $(handleInvalid))

val featureAttributesMap = $(inputCols).map { c =>
val field = SchemaUtils.getSchemaField(schema, c)
val featureAttributesMap = inputColsWithField.map { case (c, field) =>
field.dataType match {
case DoubleType =>
val attribute = Attribute.fromStructField(field)
Expand Down Expand Up @@ -144,8 +144,8 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
val assembleFunc = udf { r: Row =>
VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
dataset(c).expr.dataType match {
val args = inputColsWithField.map { case (c, field) =>
field.dataType match {
case DoubleType => dataset(c)
case _: VectorUDT => dataset(c)
case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
10 changes: 4 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputT
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -248,16 +249,13 @@
) extends SummaryBuilder {

override def summary(featuresCol: Column, weightCol: Column): Column = {

val agg = SummaryBuilderImpl.MetricsAggregate(
SummaryBuilderImpl.MetricsAggregate(
requestedMetrics,
requestedCompMetrics,
featuresCol.expr,
weightCol.expr,
featuresCol,
weightCol,
mutableAggBufferOffset = 0,
inputAggBufferOffset = 0)

Column(agg.toAggregateExpression())
}
}
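The builder now hands Columns straight to the aggregate and leans on the imported conversions, while the public Summarizer API is unchanged. A hedged usage sketch, where the DataFrame `featuresDF` with a Vector column named "features" is an assumption:

import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.functions.{col, lit}

// Compute mean and variance of a Vector column with unit weights.
val stats = featuresDF.select(
  Summarizer.metrics("mean", "variance").summary(col("features"), lit(1.0)).as("summary"))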

1 change: 1 addition & 0 deletions project/MimaExcludes.scala
@@ -110,6 +110,7 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy"),
// SPARK-49022: Use Column API
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.TypedColumn.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.this")
)

4 changes: 0 additions & 4 deletions python/pyspark/sql/classic/column.py
@@ -75,10 +75,6 @@ def _to_java_column(col: "ColumnOrName") -> "JavaObject":
return jcol


def _to_java_expr(col: "ColumnOrName") -> "JavaObject":
return _to_java_column(col).expr()


@overload
def _to_seq(sc: "SparkContext", cols: Iterable["JavaObject"]) -> "JavaObject":
...
12 changes: 6 additions & 6 deletions python/pyspark/sql/pandas/group_ops.py
@@ -238,7 +238,7 @@ def applyInPandas(
udf = pandas_udf(func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc)
return DataFrame(jdf, self.session)

def applyInPandasWithState(
Expand Down Expand Up @@ -356,7 +356,7 @@ def applyInPandasWithState(
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.applyInPandasWithState(
udf_column._jc.expr(),
udf_column._jc,
self.session._jsparkSession.parseDataType(outputStructType.json()),
self.session._jsparkSession.parseDataType(stateStructType.json()),
outputMode,
@@ -523,7 +523,7 @@ def transformWithStateUDF(
udf_column = udf(*[df[col] for col in df.columns])

jdf = self._jgd.transformWithStateInPandas(
udf_column._jc.expr(),
udf_column._jc,
self.session._jsparkSession.parseDataType(outputStructType.json()),
outputMode,
timeMode,
@@ -653,7 +653,7 @@ def applyInArrow(
) # type: ignore[call-overload]
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInArrow(udf_column._jc.expr())
jdf = self._jgd.flatMapGroupsInArrow(udf_column._jc)
return DataFrame(jdf, self.session)

def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
@@ -793,7 +793,7 @@ def applyInPandas(

all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc)
return DataFrame(jdf, self._gd1.session)

def applyInArrow(
@@ -891,7 +891,7 @@ def applyInArrow(

all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInArrow(self._gd2._jgd, udf_column._jc.expr())
jdf = self._gd1._jgd.flatMapCoGroupsInArrow(self._gd2._jgd, udf_column._jc)
return DataFrame(jdf, self._gd1.session)

@staticmethod
4 changes: 2 additions & 2 deletions python/pyspark/sql/pandas/map_ops.py
@@ -53,7 +53,7 @@ def mapInPandas(
udf_column = udf(*[self[col] for col in self.columns])

jrp = self._build_java_profile(profile)
jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier, jrp)
jdf = self._jdf.mapInPandas(udf_column._jc, barrier, jrp)
return DataFrame(jdf, self.sparkSession)

def mapInArrow(
@@ -75,7 +75,7 @@ def mapInArrow(
udf_column = udf(*[self[col] for col in self.columns])

jrp = self._build_java_profile(profile)
jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier, jrp)
jdf = self._jdf.mapInArrow(udf_column._jc, barrier, jrp)
return DataFrame(jdf, self.sparkSession)

def _build_java_profile(