@@ -772,9 +772,10 @@ def grouper(x):
772772class CSVFormatter (object ):
773773
774774 def __init__ (self , obj , path_or_buf , sep = "," , na_rep = '' , float_format = None ,
775- cols = None , header = True , index = True , index_label = None ,
776- mode = 'w' , nanRep = None , encoding = None , quoting = None ,
777- line_terminator = '\n ' , chunksize = None , engine = None ):
775+ cols = None , header = True , index = True , index_label = None ,
776+ mode = 'w' , nanRep = None , encoding = None , quoting = None ,
777+ line_terminator = '\n ' , chunksize = None , engine = None ,
778+ tupleize_cols = True ):
778779
779780 self .engine = engine # remove for 0.12
780781
@@ -803,6 +804,15 @@ def __init__(self, obj, path_or_buf, sep=",", na_rep='', float_format=None,
803804 msg = "columns.is_unique == False not supported with engine='python'"
804805 raise NotImplementedError (msg )
805806
807+ self .tupleize_cols = tupleize_cols
808+ self .has_mi_columns = isinstance (obj .columns , MultiIndex
809+ ) and not self .tupleize_cols
810+
811+ # validate mi options
812+ if self .has_mi_columns :
813+ if cols is not None :
814+ raise Exception ("cannot specify cols with a multi_index on the columns" )
815+
806816 if cols is not None :
807817 if isinstance (cols ,Index ):
808818 cols = cols .to_native_types (na_rep = na_rep ,float_format = float_format )
@@ -958,48 +968,82 @@ def _save_header(self):
958968 obj = self .obj
959969 index_label = self .index_label
960970 cols = self .cols
971+ has_mi_columns = self .has_mi_columns
961972 header = self .header
973+ encoded_labels = []
962974
963975 has_aliases = isinstance (header , (tuple , list , np .ndarray ))
964- if has_aliases or self .header :
965- if self .index :
966- # should write something for index label
967- if index_label is not False :
968- if index_label is None :
969- if isinstance (obj .index , MultiIndex ):
970- index_label = []
971- for i , name in enumerate (obj .index .names ):
972- if name is None :
973- name = ''
974- index_label .append (name )
976+ if not (has_aliases or self .header ):
977+ return
978+
979+ if self .index :
980+ # should write something for index label
981+ if index_label is not False :
982+ if index_label is None :
983+ if isinstance (obj .index , MultiIndex ):
984+ index_label = []
985+ for i , name in enumerate (obj .index .names ):
986+ if name is None :
987+ name = ''
988+ index_label .append (name )
989+ else :
990+ index_label = obj .index .name
991+ if index_label is None :
992+ index_label = ['' ]
975993 else :
976- index_label = obj .index .name
977- if index_label is None :
978- index_label = ['' ]
979- else :
980- index_label = [index_label ]
981- elif not isinstance (index_label , (list , tuple , np .ndarray )):
982- # given a string for a DF with Index
983- index_label = [index_label ]
994+ index_label = [index_label ]
995+ elif not isinstance (index_label , (list , tuple , np .ndarray )):
996+ # given a string for a DF with Index
997+ index_label = [index_label ]
984998
985- encoded_labels = list (index_label )
986- else :
987- encoded_labels = []
999+ encoded_labels = list (index_label )
1000+ else :
1001+ encoded_labels = []
9881002
989- if has_aliases :
990- if len (header ) != len (cols ):
991- raise ValueError (('Writing %d cols but got %d aliases'
992- % (len (cols ), len (header ))))
993- else :
994- write_cols = header
1003+ if has_aliases :
1004+ if len (header ) != len (cols ):
1005+ raise ValueError (('Writing %d cols but got %d aliases'
1006+ % (len (cols ), len (header ))))
9951007 else :
996- write_cols = cols
997- encoded_cols = list (write_cols )
998-
999- writer .writerow (encoded_labels + encoded_cols )
1008+ write_cols = header
10001009 else :
1001- encoded_cols = list (cols )
1002- writer .writerow (encoded_cols )
1010+ write_cols = cols
1011+
1012+ if not has_mi_columns :
1013+ encoded_labels += list (write_cols )
1014+
1015+ else :
1016+
1017+ if not has_mi_columns :
1018+ encoded_labels += list (cols )
1019+
1020+ # write out the mi
1021+ if has_mi_columns :
1022+ columns = obj .columns
1023+
1024+ # write out the names for each level, then ALL of the values for each level
1025+ for i in range (columns .nlevels ):
1026+
1027+ # we need at least 1 index column to write our col names
1028+ col_line = []
1029+ if self .index :
1030+
1031+ # name is the first column
1032+ col_line .append ( columns .names [i ] )
1033+
1034+ if isinstance (index_label ,list ) and len (index_label )> 1 :
1035+ col_line .extend ([ '' ] * (len (index_label )- 1 ))
1036+
1037+ col_line .extend (columns .get_level_values (i ))
1038+
1039+ writer .writerow (col_line )
1040+
1041+ # add blanks for the columns, so that we
1042+ # have consistent seps
1043+ encoded_labels .extend ([ '' ] * len (columns ))
1044+
1045+ # write out the index label line
1046+ writer .writerow (encoded_labels )
10031047
10041048 def _save (self ):
10051049
0 commit comments