@@ -2124,6 +2124,10 @@ def head(self, n=None):
             return rs[0] if rs else None
         return self.take(n)
 
+    def first(self):
+        """ Return the first row. """
+        return self.head()
+
     def tail(self):
         raise NotImplemented
 
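Note on this hunk: first() is just sugar for head() with no argument, so on a hypothetical DataFrame df (an assumed example, not part of this patch) the two calls below return the same single Row:

    df.first()   # first Row of the DataFrame
    df.head()    # equivalent: with n=None, head() returns one Row, not a list
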
@@ -2159,7 +2163,7 @@ def select(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     def filter(self, condition):
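This hunk (and the matching groupBy change below) stops calling a helper on the Java DataFrame and instead routes the py4j list of Java columns through the JVM-side Dsl.toColumns, so the Column varargs overload can be invoked. The Python-facing API is unchanged; a sketch with an assumed example DataFrame df:

    df.select("name", "age")          # plain names, converted to columns internally
    df.select(df.name, df.age + 1)    # Column expressions travel through Dsl.toColumns
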
@@ -2189,7 +2193,7 @@ def groupBy(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return GroupedDataFrame(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
@@ -2278,14 +2282,17 @@ def agg(self, *exprs):
         :param exprs: list of aggregate columns or a map from column
                       name to aggregate methods.
         """
+        assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             jmap = MapConverter().convert(exprs[0],
                                           self.sql_ctx._sc._gateway._gateway_client)
             jdf = self._jdf.agg(jmap)
         else:
             # Columns
-            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
-            jdf = self._jdf.agg(*exprs)
+            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+                                            self.sql_ctx._sc._gateway._gateway_client)
+            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
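With this change agg() takes either a single dict mapping column name to aggregate method, or one or more Column expressions; the first Column is passed positionally and the rest are bundled via Dsl.toColumns. A hedged usage sketch (df is an assumed example):

    df.agg({"age": "max"})            # dict form: column name -> aggregate method
    df.agg(Aggregator.min(df.age))    # Column form; the new assert rejects empty exprs
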
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):
 
 def _create_column_from_name(name):
     sc = SparkContext._active_spark_context
-    return sc._jvm.Column(name)
+    return sc._jvm.IncomputableColumn(name)
 
 
 def _scalaMethod(name):
@@ -2371,7 +2378,7 @@ def _(self):
     return _
 
 
-def _bin_op(name, pass_literal_through=False):
+def _bin_op(name, pass_literal_through=True):
     """ Create a method for given binary operator
 
     Keyword arguments:
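Making pass_literal_through default to True means a plain Python literal on the right-hand side is handed to the underlying JVM method as-is rather than being wrapped in a literal Column first (judging by the flag name; the patch itself does not spell this out). It also makes the explicit overrides on the string methods redundant, which the next hunk removes:

    df.name.like("Alice%")    # the str literal goes straight through to the JVM like()
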
@@ -2465,18 +2472,17 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
     # __getattr__ = _bin_op("getField")
 
     # string methods
-    rlike = _bin_op("rlike", pass_literal_through=True)
-    like = _bin_op("like", pass_literal_through=True)
-    startswith = _bin_op("startsWith", pass_literal_through=True)
-    endswith = _bin_op("endsWith", pass_literal_through=True)
+    rlike = _bin_op("rlike")
+    like = _bin_op("like")
+    startswith = _bin_op("startsWith")
+    endswith = _bin_op("endsWith")
     upper = _unary_op("upper")
     lower = _unary_op("lower")
 
     def substr(self, startPos, pos):
         if type(startPos) != type(pos):
             raise TypeError("Can not mix the type")
         if isinstance(startPos, (int, long)):
-
             jc = self._jc.substr(startPos, pos)
         elif isinstance(startPos, Column):
             jc = self._jc.substr(startPos._jc, pos._jc)
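substr() requires both arguments to be ints or both to be Columns; mixing them raises the TypeError above. Illustrative calls on an assumed df (df.name, df.a and df.b are hypothetical columns):

    df.name.substr(1, 3)          # both int: e.g. 'Alice' -> 'Ali'
    df.name.substr(df.a, df.b)    # both Column: bounds taken from other columns per row
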
@@ -2507,30 +2513,53 @@ def cast(self, dataType):
         return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
 
 
+def _to_java_column(col):
+    if isinstance(col, Column):
+        jcol = col._jc
+    else:
+        jcol = _create_column_from_name(col)
+    return jcol
+
+
 def _aggregate_func(name):
     """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
-        if isinstance(col, Column):
-            jcol = col._jc
-        else:
-            jcol = _create_column_from_name(col)
-        jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
+        jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
         return Column(jc)
+
     return staticmethod(_)
 
 
 class Aggregator(object):
     """
     A collection of builtin aggregators
     """
-    max = _aggregate_func("max")
-    min = _aggregate_func("min")
-    avg = mean = _aggregate_func("mean")
-    sum = _aggregate_func("sum")
-    first = _aggregate_func("first")
-    last = _aggregate_func("last")
-    count = _aggregate_func("count")
+    AGGS = [
+        'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
+        'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
+    ]
+    for _name in AGGS:
+        locals()[_name] = _aggregate_func(_name)
+    del _name
+
+    @staticmethod
+    def countDistinct(col, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+                                       sc._jvm.Dsl.toColumns(jcols))
+        return Column(jc)
+
+    @staticmethod
+    def approxCountDistinct(col, rsd=None):
+        sc = SparkContext._active_spark_context
+        if rsd is None:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+        else:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+        return Column(jc)
 
 
 def _test():
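Taken together, Aggregator now generates the simple one-argument functions from the AGGS list at class-definition time via locals(), and only the variadic countDistinct and the optional-rsd approxCountDistinct are written out by hand. A hedged usage sketch (the import path and df are assumptions, not confirmed by this patch):

    from pyspark.sql import Aggregator
    df.agg(Aggregator.count(df.age),
           Aggregator.approxCountDistinct(df.age, rsd=0.05))
    df.agg(Aggregator.countDistinct(df.name, df.age))   # distinct across several columns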