Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,18 @@ def length(col):
return Column(sc._jvm.functions.length(_to_java_column(col)))


@ignore_unicode_prefix
@since(1.5)
def initcap(col):
    """Translate the first letter of each word to upper case in the sentence.

    The work is delegated to the JVM-side ``functions.initcap``; the result is
    wrapped in a Python :class:`Column`.

    >>> sqlContext.createDataFrame([('a b',)], ['a']).select(initcap('a').alias('v')).collect()
    [Row(v=u'A B')]
    """
    # The doctest selects column 'a' (the only column of the DataFrame); the
    # string 'a b' is the row's value, not a column name.
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.initcap(_to_java_column(col)))


@ignore_unicode_prefix
@since(1.5)
def format_number(col, d):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ object FunctionRegistry {
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[FormatNumber]("format_number"),
expression[InitCap]("initcap"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[Length]("length"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,54 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
override def prettyName: String = "format_string"
}

/**
* Returns the string with the first letter of each word converted to uppercase; the remaining
* letters are left unchanged. Only the space character (' ') is treated as a word delimiter.
*/
case class InitCap(child: Expression) extends UnaryExpression
with ImplicitCastInputTypes {
override def dataType: DataType = StringType

override def inputTypes: Seq[DataType] = Seq(StringType)

// Interpreted evaluation path: upper-cases the first character of the input and every
// character that immediately follows a ' ' (space); all other characters are left as they are.
override def nullSafeEval(string: Any): Any = {
if (string.asInstanceOf[UTF8String].numBytes() == 0) {
// Empty input: hand the same value back unchanged.
return string
} else {
val sb = new StringBuffer()
// `string` is typed Any, so this resolves to StringBuffer.append(Object), i.e. its toString.
sb.append(string)
// NOTE(review): genCode uses Character.toTitleCase where this path uses toUpper; the two
// can disagree for characters whose title-case and upper-case forms differ - confirm which
// one is intended.
sb.setCharAt(0, sb.charAt(0).toUpper)
for (i <- 1 until sb.length) {
// Only the ' ' character counts as a word delimiter here.
if (sb.charAt(i - 1).equals(' ')) {
sb.setCharAt(i, sb.charAt(i).toUpper)
}
}
UTF8String.fromString(sb.toString)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should consider implementing all of this on bytes directly. The conversion to Char isn't safe. I'm not sure what happens if a character doesn't fit into a Char. Using the assumption that a lower-case and an upper-case character always have the same number of bytes, we could easily use Array[Byte]. Even though this isn't guaranteed by Unicode, it seems to be true (maybe we could propose this to Unicode). But we can do this in a follow-up PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My idea would be that we check whether the next character fits into a Char. If yes, we convert it to a Char, call Character.toUpperCase(c), and write the result back into the array. If we cannot convert it to a Char, we "ignore" it and don't change it. But as Reynold mentioned, we can do this in a second step.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think converting from bytes to chars will take some extra effort. We can do it directly on bytes (like the code below). However, I cannot guarantee that I have a complete code-to-letter table for all the special European letters.

  1. Get bytes value
  2. if (i == value of ñ) {
    i = value of Ñ;
    } // Ññ
  3. Return converted value.

}
}
/**
 * Generates the Java code for the compiled evaluation path.
 *
 * The emitted code copies the child value into a StringBuffer, title-cases the first
 * character and every character that immediately follows a ' ' (space), and stores the
 * result as a UTF8String in `ev.primitive`.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  nullSafeCodeGen(ctx, ev, (child) => {
    // Fresh names keep the generated locals from clashing with other generated code.
    val idx = ctx.freshName("idx")
    val sb = ctx.freshName("sb")
    val stringBuffer = classOf[StringBuffer].getName
    val character = classOf[Character].getName
    // The assignment to ev.primitive is deliberately outside the length check: a non-null
    // empty input must still produce a (empty) result instead of leaving it unset.
    s"""
      $stringBuffer $sb = new $stringBuffer();
      $sb.append($child);
      if ($sb.length() > 0) {
        $sb.setCharAt(0, $character.toTitleCase($sb.charAt(0)));
        for (int $idx = 1; $idx < $sb.length(); $idx++) {
          if ($sb.charAt($idx - 1) == ' ') {
            $sb.setCharAt($idx, $character.toTitleCase($sb.charAt($idx)));
          }
        }
      }
      ${ev.primitive} = UTF8String.fromString($sb.toString());
    """
  })
}
}

/**
* Returns the string which repeat the given string value n times.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
}

test("initcap unit test") {
  // A null input evaluates to null.
  checkEvaluation(InitCap(Literal(null)), null, create_row("s0"))

  // Each entry: (input literal, expected result, value placed in the input row).
  val asciiCases = Seq(
    ("a b", "A B", "s1"),
    (" a", " A", "s2"),
    ("the test", "The Test", "s3"))
  asciiCases.foreach { case (input, expected, rowValue) =>
    checkEvaluation(InitCap(Literal(input)), expected, create_row(rowValue))
  }

  // scalastyle:off
  // non ascii characters are not allowed in the code, so we disable the scalastyle here.
  val nonAscii = "世界"
  checkEvaluation(InitCap(Literal(nonAscii)), nonAscii, create_row("s4"))
  // scalastyle:on
}


test("Levenshtein distance") {
checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null)
checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null)
Expand Down
18 changes: 18 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,24 @@ object functions {
FormatString((lit(format) +: arguments).map(_.expr): _*)
}

/**
 * Returns the string with the first letter of each word converted to uppercase, by wrapping
 * the column's expression in an [[InitCap]] expression. The remaining letters are left
 * unchanged, and only the space character (' ') is treated as a word delimiter.
 *
 * @param e the string column to transform
 * @group string_funcs
 * @since 1.5.0
 */
def initcap(e: Column): Column = InitCap(e.expr)

/**
 * Returns the string with the first letter of each word converted to uppercase. The remaining
 * letters are left unchanged, and only the space character (' ') is treated as a word
 * delimiter. Convenience overload that builds a Column from the given name and delegates to
 * the Column-based `initcap`.
 *
 * @param columnName name of the string column to transform
 * @group string_funcs
 * @since 1.5.0
 */
def initcap(columnName: String): Column = initcap(Column(columnName))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since @rxin did some cleanup of the DataFrame functions, we'd better remove the columnName version of this API.


/**
* Locate the position of the first occurrence of substr column in the given string.
* Returns null if either of the arguments are null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ class StringFunctionsSuite extends QueryTest {
}
}

test("initcap function") {
  val input = Seq(("ab", "a B")).toDF("l", "r")
  val capitalized = Row("Ab", "A B")

  // DataFrame API: once through the Column overload, once through the column-name overload.
  checkAnswer(input.select(initcap($"l"), initcap("r")), capitalized)

  // The same function invoked by name inside SQL expression strings.
  checkAnswer(input.selectExpr("InitCap(l)", "InitCap(r)"), capitalized)
}

test("number format function") {
val tuple =
("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
Expand Down