From 7ce416ba8222b578dffffa34e574192e8821a5fc Mon Sep 17 00:00:00 2001
From: HuJiayin
Date: Wed, 22 Jul 2015 09:43:43 +0800
Subject: [PATCH 1/7] support initcap rebase code

---
 .../catalyst/analysis/FunctionRegistry.scala  |  1 +
 .../expressions/stringOperations.scala        | 28 +++++++++++++++++++
 .../org/apache/spark/sql/functions.scala      | 18 ++++++++++++
 .../spark/sql/StringFunctionsSuite.scala      |  9 ++++++
 4 files changed, 56 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e3d8d2adf2135..39f5d75dbf040 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -158,6 +158,7 @@ object FunctionRegistry {
     expression[Encode]("encode"),
     expression[Decode]("decode"),
     expression[FormatNumber]("format_number"),
+    expression[InitCap]("initcap"),
     expression[Lower]("lcase"),
     expression[Lower]("lower"),
     expression[Length]("length"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 1f18a6e9ff8a5..4aa9a4beeeba5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -594,6 +594,34 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC
   override def prettyName: String = "printf"
 }
 
+/**
+ * Returns string, with the first letter of each word in uppercase,
+ * all other letters in lowercase. Words are delimited by whitespace.
+ */
+case class InitCap(child: Expression) extends UnaryExpression
+  with ExpectsInputTypes with CodegenFallback {
+  override def dataType: DataType = StringType
+
+  override def inputTypes: Seq[DataType] = Seq(StringType)
+
+  override def nullSafeEval(string: Any): Any = {
+    if (string.asInstanceOf[UTF8String].getBytes.length == 0) {
+      return string
+    }
+    else {
+      val sb = new StringBuffer()
+      sb.append(string)
+      sb.setCharAt(0, sb.charAt(0).toUpper)
+      for (i <- 1 until sb.length) {
+        if (sb.charAt(i - 1).equals(' ')) {
+          sb.setCharAt(i, sb.charAt(i).toUpper)
+        }
+      }
+      UTF8String.fromString(sb.toString)
+    }
+  }
+}
+
 /**
  * Returns the string which repeat the given string value n times.
  */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e5ff8ae7e3179..3723a9e7d04d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1764,6 +1764,24 @@ object functions {
     StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*)
   }
 
+  /**
+   * Returns string, with the first letter of each word in uppercase,
+   * all other letters in lowercase. Words are delimited by whitespace.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  def initcap(e: Column): Column = InitCap(e.expr)
+
+  /**
+   * Returns string, with the first letter of each word in uppercase,
+   * all other letters in lowercase. Words are delimited by whitespace.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  def initcap(columnName: String): Column = initcap(Column(columnName))
+
   /**
    * Locate the position of the first occurrence of substr column in the given string.
    * Returns null if either of the arguments are null.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3702e73b4e74f..edf1113bb9b47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -241,6 +241,15 @@ class StringFunctionsSuite extends QueryTest {
     }
   }
 
+  test("initcap function") {
+    val df = Seq(("ab", "a B")).toDF("l", "r")
+    checkAnswer(
+      df.select(initcap($"l"), initcap("r")), Row("Ab", "A B"))
+
+    checkAnswer(
+      df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B"))
+  }
+
   test("number format function") {
     val tuple =
       ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],

From c79482d42d70eceb1829f7798ed54cdbbcfd6556 Mon Sep 17 00:00:00 2001
From: HuJiayin
Date: Thu, 23 Jul 2015 09:31:30 +0800
Subject: [PATCH 2/7] support soundex

---
 .../sql/catalyst/expressions/stringOperations.scala |  5 ++---
 .../expressions/StringExpressionsSuite.scala        | 13 +++++++++++++
 .../main/scala/org/apache/spark/sql/functions.scala |  8 --------
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 4aa9a4beeeba5..6bcbef27a975d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -605,10 +605,9 @@ case class InitCap(child: Expression) extends UnaryExpression
   override def inputTypes: Seq[DataType] = Seq(StringType)
 
   override def nullSafeEval(string: Any): Any = {
-    if (string.asInstanceOf[UTF8String].getBytes.length == 0) {
+    if (string.asInstanceOf[UTF8String].numBytes() == 0) {
       return string
-    }
-    else {
+    } else {
       val sb = new StringBuffer()
       sb.append(string)
       sb.setCharAt(0, sb.charAt(0).toUpper)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 3c2d88731beb4..d5657de4f1548 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -317,6 +317,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
   }
 
+  test("initcap unit test") {
+    checkEvaluation(InitCap(Literal(null)), null, create_row("s1"))
+    checkEvaluation(InitCap(Literal("")), "", create_row("s2"))
+    checkEvaluation(InitCap(Literal("a b")), "A B", create_row("s3"))
+    checkEvaluation(InitCap(Literal(" a")), " A", create_row("s4"))
+    checkEvaluation(InitCap(Literal("the test")), "The Test", create_row("s5"))
+    // scalastyle:off
+    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+    checkEvaluation(InitCap(Literal("世界")), "世界", create_row("s6"))
+    // scalastyle:on
+  }
+
+
   test("Levenshtein distance") {
     checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null)
     checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3723a9e7d04d4..1d2466a3c6390 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1773,14 +1773,6 @@ object functions {
    */
   def initcap(e: Column): Column = InitCap(e.expr)
 
-  /**
-   * Returns string, with the first letter of each word in uppercase,
-   * all other letters in lowercase. Words are delimited by whitespace.
-   *
-   * @group string_funcs
-   * @since 1.5.0
-   */
-  def initcap(columnName: String): Column = initcap(Column(columnName))
-
   /**
    * Locate the position of the first occurrence of substr column in the given string.
    * Returns null if either of the arguments are null.

From 6a0b958980c35b0e8da526ec9cf670d082b334e8 Mon Sep 17 00:00:00 2001
From: HuJiayin
Date: Thu, 23 Jul 2015 11:19:31 +0800
Subject: [PATCH 3/7] add column

---
 .../src/main/scala/org/apache/spark/sql/functions.scala | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 1d2466a3c6390..3723a9e7d04d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1773,6 +1773,14 @@ object functions {
    */
   def initcap(e: Column): Column = InitCap(e.expr)
 
+  /**
+   * Returns string, with the first letter of each word in uppercase,
+   * all other letters in lowercase. Words are delimited by whitespace.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  def initcap(columnName: String): Column = initcap(Column(columnName))
 
   /**
    * Locate the position of the first occurrence of substr column in the given string.

From 1f5a0efc970f2688254ab846f0ac5d4b38a8a377 Mon Sep 17 00:00:00 2001
From: HuJiayin
Date: Fri, 24 Jul 2015 16:30:49 +0800
Subject: [PATCH 4/7] add codegen

---
 .../expressions/stringOperations.scala        | 23 ++++++++++++++++++-
 .../expressions/StringExpressionsSuite.scala  | 11 ++++-----
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 631311acc5818..f42aadc09b2d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -596,7 +596,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
  * all other letters in lowercase. Words are delimited by whitespace.
  */
 case class InitCap(child: Expression) extends UnaryExpression
-  with ExpectsInputTypes with CodegenFallback {
+  with ImplicitCastInputTypes {
   override def dataType: DataType = StringType
 
   override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -616,6 +616,27 @@ case class InitCap(child: Expression) extends UnaryExpression
       UTF8String.fromString(sb.toString)
     }
   }
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (child) => {
+      val idx = ctx.freshName("idx")
+      val sb = ctx.freshName("sb")
+      val stringBuffer = classOf[StringBuffer].getName
+      val character = classOf[Character].getName
+      s"""
+        $stringBuffer $sb = new $stringBuffer();
+        $sb.append($child);
+        if ($sb.length() > 0) {
+          $sb.setCharAt(0, $character.toTitleCase($sb.charAt(0)));
+          for (int $idx = 1; $idx < $sb.length(); $idx++) {
+            if ($sb.charAt($idx - 1) == ' ') {
+              $sb.setCharAt($idx, $character.toTitleCase($sb.charAt($idx)));
+            }
+          }
+        }
+        ${ev.primitive} = UTF8String.fromString($sb.toString());
+      """
+    })
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 805aa5616611e..5179ec77dc0a0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -318,14 +318,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("initcap unit test") {
-    checkEvaluation(InitCap(Literal(null)), null, create_row("s1"))
-    checkEvaluation(InitCap(Literal("")), "", create_row("s2"))
-    checkEvaluation(InitCap(Literal("a b")), "A B", create_row("s3"))
-    checkEvaluation(InitCap(Literal(" a")), " A", create_row("s4"))
-    checkEvaluation(InitCap(Literal("the test")), "The Test", create_row("s5"))
+    checkEvaluation(InitCap(Literal(null)), null, create_row("s0"))
+    checkEvaluation(InitCap(Literal("a b")), "A B", create_row("s1"))
+    checkEvaluation(InitCap(Literal(" a")), " A", create_row("s2"))
+    checkEvaluation(InitCap(Literal("the test")), "The Test", create_row("s3"))
     // scalastyle:off
     // non ascii characters are not allowed in the code, so we disable the scalastyle here.
-    checkEvaluation(InitCap(Literal("世界")), "世界", create_row("s6"))
+    checkEvaluation(InitCap(Literal("世界")), "世界", create_row("s4"))
     // scalastyle:on
   }
 

From b616c0e4223379cae162ec24c28f0a49e141acbf Mon Sep 17 00:00:00 2001
From: HuJiayin
Date: Fri, 31 Jul 2015 14:44:12 +0800
Subject: [PATCH 5/7] add python api

---
 python/pyspark/sql/functions.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 719e623a1a11f..6f1ca0b286c5a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -553,6 +553,18 @@ def length(col):
     return Column(sc._jvm.functions.length(_to_java_column(col)))
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def initcap(col):
+    """Translate the first letter of each word to upper case in the sentence.
+
+    >>> sqlContext.createDataFrame([('the test',)], ['a']).select(initcap('the test').alias('v')).collect()
+    [Row(v='The Test')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.initcap(_to_java_column(col)))
+
+
 @ignore_unicode_prefix
 @since(1.5)
 def format_number(col, d):

From 2cd43e5ed2e2676755c304ab051826e9ea44d4ff Mon Sep 17 00:00:00 2001
From: HuJiayin
Date: Fri, 31 Jul 2015 16:29:35 +0800
Subject: [PATCH 6/7] fix python style check

---
 python/pyspark/sql/functions.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6f1ca0b286c5a..b23cc5beadab6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -558,7 +558,8 @@ def length(col):
 def initcap(col):
     """Translate the first letter of each word to upper case in the sentence.
 
-    >>> sqlContext.createDataFrame([('the test',)], ['a']).select(initcap('the test').alias('v')).collect()
+    >>> sqlContext.createDataFrame([('the test',)], ['a']).select(initcap('the test')\
+    .alias('v')).collect()
     [Row(v='The Test')]
     """
     sc = SparkContext._active_spark_context

From 8b2506ad605a035e41206d83eca064ae5a447ac7 Mon Sep 17 00:00:00 2001
From: HuJiayin
Date: Fri, 31 Jul 2015 20:52:25 +0800
Subject: [PATCH 7/7] Update functions.py

---
 python/pyspark/sql/functions.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b23cc5beadab6..aeca8064b61e7 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -557,10 +557,9 @@ def length(col):
 @since(1.5)
 def initcap(col):
     """Translate the first letter of each word to upper case in the sentence.
-
-    >>> sqlContext.createDataFrame([('the test',)], ['a']).select(initcap('the test')\
-    .alias('v')).collect()
-    [Row(v='The Test')]
+
+    >>> sqlContext.createDataFrame([('a b',)], ['a']).select(initcap('a').alias('v')).collect()
+    [Row(v='A B')]
     """
    sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.initcap(_to_java_column(col)))
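
For reference, a minimal usage sketch of the Scala API this series adds; it is illustrative only and not part of the patches. It assumes a Spark 1.5-style SQLContext named sqlContext with its implicits imported, and the DataFrame and column names below are made up for the example.

    import org.apache.spark.sql.functions.initcap
    import sqlContext.implicits._

    // Illustrative DataFrame with two lowercase string columns.
    val df = Seq(("spark sql", "hello world")).toDF("l", "r")

    // Column-based and column-name-based overloads (patches 1 and 3).
    df.select(initcap($"l"), initcap("r")).show()   // "Spark Sql", "Hello World"

    // The expression is also usable from selectExpr/SQL via the FunctionRegistry entry (patch 1).
    df.selectExpr("initcap(l)").show()              // "Spark Sql"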