18 changes: 10 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
* :: AlphaComponent ::
* A tokenizer that converts the input string to lowercase and then splits it by white spaces.
*
* @see [[RegexTokenizer]]
*/
@AlphaComponent
class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] {
@@ -45,9 +47,9 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S

/**
* :: AlphaComponent ::
* A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default)
* or using it to split the text (set matching to false). Optional parameters also allow filtering
* tokens using a minimal length.
* A regex based tokenizer that extracts tokens either by using the provided regex pattern to split
* the text (default) or repeatedly matching the regex (if `gaps` is true).
* Optional parameters also allow filtering tokens using a minimal length.
* It returns an array of strings that can be empty.
*/
@AlphaComponent
@@ -71,8 +73,8 @@ class RegexTokenizer(override val uid: String)
def getMinTokenLength: Int = $(minTokenLength)

/**
* Indicates whether regex splits on gaps (true) or matching tokens (false).
* Default: false
* Indicates whether regex splits on gaps (true) or matches tokens (false).
* Default: true
* @group param
*/
val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")
@@ -84,8 +86,8 @@ class RegexTokenizer(override val uid: String)
def getGaps: Boolean = $(gaps)

/**
* Regex pattern used by tokenizer.
* Default: `"\\p{L}+|[^\\p{L}\\s]+"`
* Regex pattern used to match delimiters if [[gaps]] is true or tokens if [[gaps]] is false.
* Default: `"\\s+"`
* @group param
*/
val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")
@@ -96,7 +98,7 @@ class RegexTokenizer(override val uid: String)
/** @group getParam */
def getPattern: String = $(pattern)

setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+")
setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+")

override protected def createTransformFunc: String => Seq[String] = { str =>
val re = $(pattern).r
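Side note (not part of the patch): a minimal Scala sketch of what the new defaults mean in practice. It assumes a spark-shell style `sqlContext`, mirroring the test suite below; the input string and column names are only for illustration.

```scala
import org.apache.spark.ml.feature.RegexTokenizer

// Under the new defaults (gaps = true, pattern = "\\s+") the tokenizer splits each
// input string on runs of whitespace instead of matching word/punctuation tokens.
val df = sqlContext.createDataFrame(Seq(Tuple1("Te,st. punct"))).toDF("rawText")

val tokenizer = new RegexTokenizer()
  .setInputCol("rawText")
  .setOutputCol("tokens")

tokenizer.transform(df).select("tokens").collect()
// Each row should now hold Seq("Te,st.", "punct"); with the old defaults it would
// have been Seq("Te", ",", "st", ".", "punct").
```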
@@ -29,35 +29,34 @@ case class TokenizerTestData(rawText: String, wantedTokens: Array[String])

class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._

test("RegexTokenizer") {
val tokenizer = new RegexTokenizer()
val tokenizer0 = new RegexTokenizer()
.setGaps(false)
.setPattern("\\w+|\\p{Punct}")
.setInputCol("rawText")
.setOutputCol("tokens")

val dataset0 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")),
TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct"))
))
testRegexTokenizer(tokenizer, dataset0)
testRegexTokenizer(tokenizer0, dataset0)

val dataset1 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Array("punct"))
))
tokenizer0.setMinTokenLength(3)
testRegexTokenizer(tokenizer0, dataset1)

tokenizer.setMinTokenLength(3)
testRegexTokenizer(tokenizer, dataset1)

tokenizer
.setPattern("\\s")
.setGaps(true)
.setMinTokenLength(0)
val tokenizer2 = new RegexTokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")
val dataset2 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")),
TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct"))
TokenizerTestData("Te,st. punct", Array("Te,st.", "punct"))
))
testRegexTokenizer(tokenizer, dataset2)
testRegexTokenizer(tokenizer2, dataset2)
}
}

@@ -67,9 +66,8 @@ object RegexTokenizerSuite extends FunSuite {
t.transform(dataset)
.select("tokens", "wantedTokens")
.collect()
.foreach {
case Row(tokens, wantedTokens) =>
assert(tokens === wantedTokens)
}
.foreach { case Row(tokens, wantedTokens) =>
assert(tokens === wantedTokens)
}
}
}
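For reviewers who want to see the two modes outside the pipeline, here is a standalone sketch (illustrative only, not part of the patch) using plain Scala regexes; it mirrors what the updated suite asserts.

```scala
val text = "Te,st. punct"

// gaps = true (the new default): the pattern describes the delimiters, so split on it.
val byGaps: Seq[String] = "\\s+".r.split(text).toSeq
// byGaps == Seq("Te,st.", "punct")

// gaps = false: the pattern describes the tokens themselves, so match it repeatedly.
val byMatches: Seq[String] = "\\w+|\\p{Punct}".r.findAllIn(text).toSeq
// byMatches == Seq("Te", ",", "st", ".", "punct")
```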
40 changes: 19 additions & 21 deletions python/pyspark/ml/feature.py
@@ -446,23 +446,25 @@ def getDegree(self):
@ignore_unicode_prefix
class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default)
or using it to split the text (set matching to false). Optional parameters also allow filtering
tokens using a minimal length.
A regex based tokenizer that extracts tokens either by using the
provided regex pattern (in Java dialect) to split the text
(default) or repeatedly matching the regex (if gaps is true).
Optional parameters also allow filtering tokens using a minimal
length.
It returns an array of strings that can be empty.

>>> df = sqlContext.createDataFrame([("a b c",)], ["text"])
>>> df = sqlContext.createDataFrame([("a b c",)], ["text"])
>>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words")
>>> reTokenizer.transform(df).head()
Row(text=u'a b c', words=[u'a', u'b', u'c'])
Row(text=u'a b c', words=[u'a', u'b', u'c'])
>>> # Change a parameter.
>>> reTokenizer.setParams(outputCol="tokens").transform(df).head()
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
>>> # Temporarily modify a parameter.
>>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head()
Row(text=u'a b c', words=[u'a', u'b', u'c'])
Row(text=u'a b c', words=[u'a', u'b', u'c'])
>>> reTokenizer.transform(df).head()
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
>>> # Must use keyword arguments to specify params.
>>> reTokenizer.setParams("text")
Traceback (most recent call last):
@@ -472,31 +474,27 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):

# a placeholder to make it appear in the generated doc
minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)")
gaps = Param(Params._dummy(), "gaps", "Set regex to match gaps or tokens")
pattern = Param(Params._dummy(), "pattern", "regex pattern used for tokenizing")
gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens")
pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing")

@keyword_only
def __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+",
inputCol=None, outputCol=None):
def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None):
"""
__init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", \
inputCol=None, outputCol=None)
__init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None)
"""
super(RegexTokenizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid)
self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)")
self.gaps = Param(self, "gaps", "Set regex to match gaps or tokens")
self.pattern = Param(self, "pattern", "regex pattern used for tokenizing")
self._setDefault(minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+")
self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens")
self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing")
self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+",
inputCol=None, outputCol=None):
def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None):
"""
setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", \
inputCol="input", outputCol="output")
setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None)
Sets params for this RegexTokenizer.
"""
kwargs = self.setParams._input_kwargs
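Because the defaults change, callers that relied on the old token-matching behaviour need to opt back in explicitly once this lands. A hedged sketch in Scala (the setter names and old default values are taken from the diff above; the Python parameters in feature.py mirror them):

```scala
import org.apache.spark.ml.feature.RegexTokenizer

// Reproduce the pre-patch behaviour: match tokens rather than split on gaps.
val legacyTokenizer = new RegexTokenizer()
  .setGaps(false)
  .setPattern("\\p{L}+|[^\\p{L}\\s]+")
  .setInputCol("rawText")
  .setOutputCol("tokens")
```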