@@ -2124,6 +2124,10 @@ def head(self, n=None):
             return rs[0] if rs else None
         return self.take(n)
 
+    def first(self):
+        """ Return the first row. """
+        return self.head()
+
     def tail(self):
         raise NotImplementedError
 
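A minimal usage sketch of the new first() helper; df is an illustrative DataFrame, not one defined in this patch:

    >>> df.first()   # a single Row (or None if the DataFrame is empty), same as df.head()
    >>> df.head(2)   # with n given, head() delegates to take(n) and returns a list of Rows
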
@@ -2159,7 +2163,7 @@ def select(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     def filter(self, condition):
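Both call styles reach this conversion; column names take the branch elided above, while Column objects fall into the else branch shown, whose _jc handles are collected into jcols (df is illustrative):

    >>> df.select("name", "age")          # column names
    >>> df.select(df.name, df.age + 1)    # Column expressions
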
@@ -2189,7 +2193,7 @@ def groupBy(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return GroupedDataFrame(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
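groupBy() accepts the same two spellings, e.g. (illustrative df):

    >>> df.groupBy("department")     # by name
    >>> df.groupBy(df.department)    # by Column; both now go through Dsl.toColumns
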
@@ -2278,14 +2282,17 @@ def agg(self, *exprs):
         :param exprs: list of aggregate columns or a map from column
                       name to aggregate methods.
         """
+        assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             jmap = MapConverter().convert(exprs[0],
                                           self.sql_ctx._sc._gateway._gateway_client)
             jdf = self._jdf.agg(jmap)
         else:
             # Columns
-            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
-            jdf = self._jdf.agg(*exprs)
+            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+                                            self.sql_ctx._sc._gateway._gateway_client)
+            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
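The two branches correspond to two call styles. The Column branch splits off the first expression because the underlying Java agg takes a lead Column plus a Column array, which is also why a non-empty exprs is now asserted. A hedged sketch (df is illustrative; the Aggregator helpers are defined later in this patch):

    >>> df.agg({"age": "max"})                                  # dict: column name -> aggregate method
    >>> df.agg(Aggregator.min(df.age), Aggregator.max(df.age))  # one or more Column expressions
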
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):
 
 def _create_column_from_name(name):
     sc = SparkContext._active_spark_context
-    return sc._jvm.Column(name)
+    return sc._jvm.IncomputableColumn(name)
 
 
 def _scalaMethod(name):
@@ -2371,7 +2378,7 @@ def _(self):
     return _
 
 
-def _bin_op(name, pass_literal_through=False):
+def _bin_op(name, pass_literal_through=True):
     """ Create a method for given binary operator
 
     Keyword arguments:
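The body of _bin_op is not part of this hunk. Under the assumption that pass_literal_through controls whether a plain Python operand is handed to the JVM method as-is (True) or first wrapped as a literal Column (False), the generated wrapper behaves roughly like this sketch:

    def _(self, other):
        if isinstance(other, Column):
            jc = other._jc
        elif pass_literal_through:
            jc = other                                # raw literal goes straight to the JVM method
        else:
            jc = _create_column_from_literal(other)
        return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)

Flipping the default to True lets the string methods below drop their explicit keyword argument.
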
@@ -2465,18 +2472,17 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
     # __getattr__ = _bin_op("getField")
 
     # string methods
-    rlike = _bin_op("rlike", pass_literal_through=True)
-    like = _bin_op("like", pass_literal_through=True)
-    startswith = _bin_op("startsWith", pass_literal_through=True)
-    endswith = _bin_op("endsWith", pass_literal_through=True)
+    rlike = _bin_op("rlike")
+    like = _bin_op("like")
+    startswith = _bin_op("startsWith")
+    endswith = _bin_op("endsWith")
     upper = _unary_op("upper")
     lower = _unary_op("lower")
 
     def substr(self, startPos, pos):
         if type(startPos) != type(pos):
             raise TypeError("Can not mix the types")
         if isinstance(startPos, (int, long)):
-
             jc = self._jc.substr(startPos, pos)
         elif isinstance(startPos, Column):
             jc = self._jc.substr(startPos._jc, pos._jc)
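With literals passed through by default, the string operators read naturally (illustrative df):

    >>> df.name.like("Al%")          # SQL LIKE pattern
    >>> df.name.rlike("^Al.*")       # regular expression match
    >>> df.name.startswith("Al")
    >>> df.name.substr(1, 3)         # both ints, or both Columns; mixing the two raises TypeError
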
@@ -2507,30 +2513,53 @@ def cast(self, dataType):
         return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
 
 
+def _to_java_column(col):
+    if isinstance(col, Column):
+        jcol = col._jc
+    else:
+        jcol = _create_column_from_name(col)
+    return jcol
+
+
 def _aggregate_func(name):
     """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
-        if isinstance(col, Column):
-            jcol = col._jc
-        else:
-            jcol = _create_column_from_name(col)
-        jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
+        jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
         return Column(jc)
+
     return staticmethod(_)
 
 
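Factoring out _to_java_column is what lets every generated aggregate accept either a Column or a bare column name, e.g. (illustrative df, using the Aggregator class defined next):

    >>> Aggregator.max(df.age)    # Column: its _jc handle is used directly
    >>> Aggregator.max("age")     # str: wrapped via _create_column_from_name
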
 class Aggregator(object):
     """
     A collection of builtin aggregators
     """
-    max = _aggregate_func("max")
-    min = _aggregate_func("min")
-    avg = mean = _aggregate_func("mean")
-    sum = _aggregate_func("sum")
-    first = _aggregate_func("first")
-    last = _aggregate_func("last")
-    count = _aggregate_func("count")
+    AGGS = [
+        'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
+        'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
+    ]
+    for _name in AGGS:
+        locals()[_name] = _aggregate_func(_name)
+    del _name
+
+    @staticmethod
+    def countDistinct(col, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+                                       sc._jvm.Dsl.toColumns(jcols))
+        return Column(jc)
+
+    @staticmethod
+    def approxCountDistinct(col, rsd=None):
+        sc = SparkContext._active_spark_context
+        if rsd is None:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+        else:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+        return Column(jc)
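Usage sketch for the two hand-written aggregates (illustrative df; rsd is presumably the allowed relative standard deviation of the estimate):

    >>> Aggregator.countDistinct(df.age, df.name)      # extra columns travel via Dsl.toColumns
    >>> Aggregator.approxCountDistinct(df.age)         # default precision
    >>> Aggregator.approxCountDistinct(df.age, 0.05)   # explicit rsd
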
 
 
 def _test():