@@ -205,9 +205,15 @@ def _args_adjust(self):
205205 def _setup_subplots (self ):
206206 if self .subplots :
207207 nrows , ncols = self ._get_layout ()
208- fig , axes = _subplots (nrows = nrows , ncols = ncols ,
209- sharex = self .sharex , sharey = self .sharey ,
210- figsize = self .figsize )
208+ if self .ax is None :
209+ fig , axes = _subplots (nrows = nrows , ncols = ncols ,
210+ sharex = self .sharex , sharey = self .sharey ,
211+ figsize = self .figsize )
212+ else :
213+ fig , axes = _subplots (nrows = nrows , ncols = ncols ,
214+ sharex = self .sharex , sharey = self .sharey ,
215+ figsize = self .figsize , ax = self .ax )
216+
211217 else :
212218 if self .ax is None :
213219 fig = self .plt .figure (figsize = self .figsize )
@@ -509,10 +515,13 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
509515 -------
510516 ax_or_axes : matplotlib.AxesSubplot or list of them
511517 """
518+ kind = kind .lower ().strip ()
512519 if kind == 'line' :
513520 klass = LinePlot
514521 elif kind in ('bar' , 'barh' ):
515522 klass = BarPlot
523+ else :
524+ raise ValueError ('Invalid chart type given %s' % kind )
516525
517526 plot_obj = klass (frame , kind = kind , subplots = subplots , rot = rot ,
518527 legend = legend , ax = ax , fontsize = fontsize ,
@@ -691,49 +700,84 @@ def plot_group(group, ax):
691700 ax .scatter (xvals , yvals )
692701
693702 if by is not None :
694- fig = _grouped_plot (plot_group , data , by = by , figsize = figsize )
703+ fig = _grouped_plot (plot_group , data , by = by , figsize = figsize , ax = ax )
695704 else :
696- fig = plt .figure ()
697- ax = fig .add_subplot (111 )
705+ if ax is None :
706+ fig = plt .figure ()
707+ ax = fig .add_subplot (111 )
708+ else :
709+ fig = ax .get_figure ()
698710 plot_group (data , ax )
699711 ax .set_ylabel (str (y ))
700712 ax .set_xlabel (str (x ))
701713
702714 return fig
703715
704716
705- def hist_frame (data , grid = True , ** kwds ):
717+ def hist_frame (data , grid = True , xlabelsize = None , xrot = None ,
718+ ylabelsize = None , yrot = None , ax = None , ** kwds ):
706719 """
707720 Draw Histogram the DataFrame's series using matplotlib / pylab.
708721
709722 Parameters
710723 ----------
724+ grid : boolean, default True
725+ Whether to show axis grid lines
726+ xlabelsize : int, default None
727+ If specified changes the x-axis label size
728+ xrot : float, default None
729+ rotation of x axis labels
730+ ylabelsize : int, default None
731+ If specified changes the y-axis label size
732+ yrot : float, default None
733+ rotation of y axis labels
734+ ax : matplotlib axes object, default None
711735 kwds : other plotting keyword arguments
712736 To be passed to hist function
713737 """
738+ import matplotlib .pyplot as plt
714739 n = len (data .columns )
715740 k = 1
716741 while k ** 2 < n :
717742 k += 1
718- _ , axes = _subplots (nrows = k , ncols = k )
743+ _ , axes = _subplots (nrows = k , ncols = k , ax = ax )
719744
720745 for i , col in enumerate (com ._try_sort (data .columns )):
721746 ax = axes [i / k ][i % k ]
722747 ax .hist (data [col ].dropna ().values , ** kwds )
723748 ax .set_title (col )
724749 ax .grid (grid )
725750
726- return axes
751+ if xlabelsize is not None :
752+ plt .setp (ax .get_xticklabels (), fontsize = xlabelsize )
753+ if xrot is not None :
754+ plt .setp (ax .get_xticklabels (), rotation = xrot )
755+ if ylabelsize is not None :
756+ plt .setp (ax .get_yticklabels (), fontsize = ylabelsize )
757+ if yrot is not None :
758+ plt .setp (ax .get_yticklabels (), rotation = yrot )
727759
760+ return axes
728761
729- def hist_series (self , ax = None , grid = True , ** kwds ):
762+ def hist_series (self , ax = None , grid = True , xlabelsize = None , xrot = None ,
763+ ylabelsize = None , yrot = None , ** kwds ):
730764 """
731765 Draw histogram of the input series using matplotlib
732766
733767 Parameters
734768 ----------
735769 ax : matplotlib axis object
736770 If not passed, uses gca()
771+ grid : boolean, default True
772+ Whether to show axis grid lines
773+ xlabelsize : int, default None
774+ If specified changes the x-axis label size
775+ xrot : float, default None
776+ rotation of x axis labels
777+ ylabelsize : int, default None
778+ If specified changes the y-axis label size
779+ yrot : float, default None
780+ rotation of y axis labels
737781 kwds : keywords
738782 To be passed to the actual plotting function
739783
@@ -752,12 +796,21 @@ def hist_series(self, ax=None, grid=True, **kwds):
752796 ax .hist (values , ** kwds )
753797 ax .grid (grid )
754798
799+ if xlabelsize is not None :
800+ plt .setp (ax .get_xticklabels (), fontsize = xlabelsize )
801+ if xrot is not None :
802+ plt .setp (ax .get_xticklabels (), rotation = xrot )
803+ if ylabelsize is not None :
804+ plt .setp (ax .get_yticklabels (), fontsize = ylabelsize )
805+ if yrot is not None :
806+ plt .setp (ax .get_yticklabels (), rotation = yrot )
807+
755808 return ax
756809
757810
758811def _grouped_plot (plotf , data , column = None , by = None , numeric_only = True ,
759812 figsize = None , sharex = True , sharey = True , layout = None ,
760- rot = 0 ):
813+ rot = 0 , ax = None ):
761814 from pandas .core .frame import DataFrame
762815
763816 # allow to specify mpl default with 'default'
@@ -777,7 +830,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
777830 # default size
778831 figsize = (10 , 5 )
779832 fig , axes = _subplots (nrows = nrows , ncols = ncols , figsize = figsize ,
780- sharex = sharex , sharey = sharey )
833+ sharex = sharex , sharey = sharey , ax = ax )
781834
782835 ravel_axes = []
783836 for row in axes :
@@ -794,7 +847,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
794847
795848def _grouped_plot_by_column (plotf , data , columns = None , by = None ,
796849 numeric_only = True , grid = False ,
797- figsize = None ):
850+ figsize = None , ax = None ):
798851 import matplotlib .pyplot as plt
799852
800853 grouped = data .groupby (by )
@@ -805,7 +858,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
805858 nrows , ncols = _get_layout (ngroups )
806859 fig , axes = _subplots (nrows = nrows , ncols = ncols ,
807860 sharex = True , sharey = True ,
808- figsize = figsize )
861+ figsize = figsize , ax = ax )
809862
810863 if isinstance (axes , plt .Axes ):
811864 ravel_axes = [axes ]
@@ -850,7 +903,7 @@ def _get_layout(nplots):
850903# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0
851904
852905def _subplots (nrows = 1 , ncols = 1 , sharex = False , sharey = False , squeeze = True ,
853- subplot_kw = None , ** fig_kw ):
906+ subplot_kw = None , ax = None , ** fig_kw ):
854907 """Create a figure with a set of subplots already made.
855908
856909 This utility wrapper makes it convenient to create common layouts of
@@ -890,6 +943,8 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
890943 Dict with keywords passed to the figure() call. Note that all keywords
891944 not recognized above will be automatically included here.
892945
946+ ax : Matplotlib axis object, default None
947+
893948 Returns:
894949
895950 fig, ax : tuple
@@ -922,7 +977,10 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
922977 if subplot_kw is None :
923978 subplot_kw = {}
924979
925- fig = plt .figure (** fig_kw )
980+ if ax is None :
981+ fig = plt .figure (** fig_kw )
982+ else :
983+ fig = ax .get_figure ()
926984
927985 # Create empty object array to hold all axes. It's easiest to make it 1-d
928986 # so we can just append subplots upon creation, and then
0 commit comments