Skip to content

Commit 016e8a7

Browse files
Shitinemccarthy
authored andcommitted
[SPARK-7295][SQL] bitwise operations for DataFrame DSL
Author: Shiti <[email protected]> Closes apache#5867 from Shiti/spark-7295 and squashes the following commits: 71a9913 [Shiti] implementation for bitwise and,or, not and xor on Column with tests and docs
1 parent f86b8e5 commit 016e8a7

File tree

7 files changed

+97
-2
lines changed

7 files changed

+97
-2
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,11 @@ def __init__(self, jc):
12771277
__contains__ = _bin_op("contains")
12781278
__getitem__ = _bin_op("getItem")
12791279

1280+
# bitwise operators
1281+
bitwiseOR = _bin_op("bitwiseOR")
1282+
bitwiseAND = _bin_op("bitwiseAND")
1283+
bitwiseXOR = _bin_op("bitwiseXOR")
1284+
12801285
def getItem(self, key):
12811286
"""An expression that gets an item at position `ordinal` out of a list,
12821287
or gets an item by key out of a dict.

python/pyspark/sql/functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def _(col1, col2):
104104
'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
105105
'measured in radians.',
106106

107+
'bitwiseNOT': 'Computes bitwise not.',
108+
107109
'max': 'Aggregate function: returns the maximum value of the expression in a group.',
108110
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
109111
'first': 'Aggregate function: returns the first value in a group.',

python/pyspark/sql/tests.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,19 @@ def test_fillna(self):
645645
self.assertEqual(row.age, None)
646646
self.assertEqual(row.height, None)
647647

648+
def test_bitwise_operations(self):
649+
from pyspark.sql import functions
650+
row = Row(a=170, b=75)
651+
df = self.sqlCtx.createDataFrame([row])
652+
result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
653+
self.assertEqual(170 & 75, result['(a & b)'])
654+
result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
655+
self.assertEqual(170 | 75, result['(a | b)'])
656+
result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
657+
self.assertEqual(170 ^ 75, result['(a ^ b)'])
658+
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
659+
self.assertEqual(~75, result['~b'])
660+
648661

649662
class HiveContextSQLTests(ReusedPySparkTestCase):
650663

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,37 @@ class Column(protected[sql] val expr: Expression) extends Logging {
698698
println(expr.prettyString)
699699
}
700700
}
701+
702+
/**
703+
* Compute bitwise OR of this expression with another expression.
704+
* {{{
705+
* df.select($"colA".bitwiseOR($"colB"))
706+
* }}}
707+
*
708+
* @group expr_ops
709+
*/
710+
def bitwiseOR(other: Any): Column = BitwiseOr(expr, lit(other).expr)
711+
712+
/**
713+
* Compute bitwise AND of this expression with another expression.
714+
* {{{
715+
* df.select($"colA".bitwiseAND($"colB"))
716+
* }}}
717+
*
718+
* @group expr_ops
719+
*/
720+
def bitwiseAND(other: Any): Column = BitwiseAnd(expr, lit(other).expr)
721+
722+
/**
723+
* Compute bitwise XOR of this expression with another expression.
724+
* {{{
725+
* df.select($"colA".bitwiseXOR($"colB"))
726+
* }}}
727+
*
728+
* @group expr_ops
729+
*/
730+
def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr)
731+
701732
}
702733

703734

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,14 @@ object functions {
438438
*/
439439
def upper(e: Column): Column = Upper(e.expr)
440440

441+
442+
/**
443+
* Computes bitwise NOT.
444+
*
445+
* @group normal_funcs
446+
*/
447+
def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr)
448+
441449
//////////////////////////////////////////////////////////////////////////////////////////////
442450
// Math Functions
443451
//////////////////////////////////////////////////////////////////////////////////////////////

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ import org.apache.spark.sql.types._
2727
class ColumnExpressionSuite extends QueryTest {
2828
import org.apache.spark.sql.TestData._
2929

30-
// TODO: Add test cases for bitwise operations.
31-
3230
test("collect on column produced by a binary operator") {
3331
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
3432
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
@@ -385,4 +383,35 @@ class ColumnExpressionSuite extends QueryTest {
385383
assert(row.getDouble(1) >= -4.0)
386384
}
387385
}
386+
387+
test("bitwiseAND") {
388+
checkAnswer(
389+
testData2.select($"a".bitwiseAND(75)),
390+
testData2.collect().toSeq.map(r => Row(r.getInt(0) & 75)))
391+
392+
checkAnswer(
393+
testData2.select($"a".bitwiseAND($"b").bitwiseAND(22)),
394+
testData2.collect().toSeq.map(r => Row(r.getInt(0) & r.getInt(1) & 22)))
395+
}
396+
397+
test("bitwiseOR") {
398+
checkAnswer(
399+
testData2.select($"a".bitwiseOR(170)),
400+
testData2.collect().toSeq.map(r => Row(r.getInt(0) | 170)))
401+
402+
checkAnswer(
403+
testData2.select($"a".bitwiseOR($"b").bitwiseOR(42)),
404+
testData2.collect().toSeq.map(r => Row(r.getInt(0) | r.getInt(1) | 42)))
405+
}
406+
407+
test("bitwiseXOR") {
408+
checkAnswer(
409+
testData2.select($"a".bitwiseXOR(112)),
410+
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ 112)))
411+
412+
checkAnswer(
413+
testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
414+
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
415+
}
416+
388417
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql
1919

20+
import org.apache.spark.sql.TestData._
2021
import org.apache.spark.sql.functions._
2122
import org.apache.spark.sql.test.TestSQLContext.implicits._
2223
import org.apache.spark.sql.types._
@@ -81,4 +82,10 @@ class DataFrameFunctionsSuite extends QueryTest {
8182
struct(col("a") * 2)
8283
}
8384
}
85+
86+
test("bitwiseNOT") {
87+
checkAnswer(
88+
testData2.select(bitwiseNOT($"a")),
89+
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
90+
}
8491
}

0 commit comments

Comments
 (0)