22import pytest
33
44import pandas as pd
5+ from pandas import DataFrame , Series , date_range
56import pandas ._testing as tm
67
78
89class TestDataFrameTruncate :
9- def test_truncate (self , datetime_frame ):
10+ def test_truncate (self , datetime_frame , frame_or_series ):
1011 ts = datetime_frame [::3 ]
12+ if frame_or_series is Series :
13+ ts = ts .iloc [:, 0 ]
1114
1215 start , end = datetime_frame .index [3 ], datetime_frame .index [6 ]
1316
@@ -16,34 +19,41 @@ def test_truncate(self, datetime_frame):
1619
1720 # neither specified
1821 truncated = ts .truncate ()
19- tm .assert_frame_equal (truncated , ts )
22+ tm .assert_equal (truncated , ts )
2023
2124 # both specified
2225 expected = ts [1 :3 ]
2326
2427 truncated = ts .truncate (start , end )
25- tm .assert_frame_equal (truncated , expected )
28+ tm .assert_equal (truncated , expected )
2629
2730 truncated = ts .truncate (start_missing , end_missing )
28- tm .assert_frame_equal (truncated , expected )
31+ tm .assert_equal (truncated , expected )
2932
3033 # start specified
3134 expected = ts [1 :]
3235
3336 truncated = ts .truncate (before = start )
34- tm .assert_frame_equal (truncated , expected )
37+ tm .assert_equal (truncated , expected )
3538
3639 truncated = ts .truncate (before = start_missing )
37- tm .assert_frame_equal (truncated , expected )
40+ tm .assert_equal (truncated , expected )
3841
3942 # end specified
4043 expected = ts [:3 ]
4144
4245 truncated = ts .truncate (after = end )
43- tm .assert_frame_equal (truncated , expected )
46+ tm .assert_equal (truncated , expected )
4447
4548 truncated = ts .truncate (after = end_missing )
46- tm .assert_frame_equal (truncated , expected )
49+ tm .assert_equal (truncated , expected )
50+
51+ # corner case, empty series/frame returned
52+ truncated = ts .truncate (after = ts .index [0 ] - ts .index .freq )
53+ assert len (truncated ) == 0
54+
55+ truncated = ts .truncate (before = ts .index [- 1 ] + ts .index .freq )
56+ assert len (truncated ) == 0
4757
4858 msg = "Truncate: 2000-01-06 00:00:00 must be after 2000-02-04 00:00:00"
4959 with pytest .raises (ValueError , match = msg ):
@@ -57,25 +67,35 @@ def test_truncate_copy(self, datetime_frame):
5767 truncated .values [:] = 5.0
5868 assert not (datetime_frame .values [5 :11 ] == 5 ).any ()
5969
60- def test_truncate_nonsortedindex (self ):
70+ def test_truncate_nonsortedindex (self , frame_or_series ):
6171 # GH#17935
6272
63- df = pd .DataFrame ({"A" : ["a" , "b" , "c" , "d" , "e" ]}, index = [5 , 3 , 2 , 9 , 0 ])
73+ obj = DataFrame ({"A" : ["a" , "b" , "c" , "d" , "e" ]}, index = [5 , 3 , 2 , 9 , 0 ])
74+ if frame_or_series is Series :
75+ obj = obj ["A" ]
76+
6477 msg = "truncate requires a sorted index"
6578 with pytest .raises (ValueError , match = msg ):
66- df .truncate (before = 3 , after = 9 )
79+ obj .truncate (before = 3 , after = 9 )
80+
81+ def test_sort_values_nonsortedindex (self ):
82+ # TODO: belongs elsewhere?
6783
68- rng = pd . date_range ("2011-01-01" , "2012-01-01" , freq = "W" )
69- ts = pd . DataFrame (
84+ rng = date_range ("2011-01-01" , "2012-01-01" , freq = "W" )
85+ ts = DataFrame (
7086 {"A" : np .random .randn (len (rng )), "B" : np .random .randn (len (rng ))}, index = rng
7187 )
88+
7289 msg = "truncate requires a sorted index"
7390 with pytest .raises (ValueError , match = msg ):
7491 ts .sort_values ("A" , ascending = False ).truncate (
7592 before = "2011-11" , after = "2011-12"
7693 )
7794
78- df = pd .DataFrame (
95+ def test_truncate_nonsortedindex_axis1 (self ):
96+ # GH#17935
97+
98+ df = DataFrame (
7999 {
80100 3 : np .random .randn (5 ),
81101 20 : np .random .randn (5 ),
@@ -93,27 +113,34 @@ def test_truncate_nonsortedindex(self):
93113 [(1 , 2 , [2 , 1 ]), (None , 2 , [2 , 1 , 0 ]), (1 , None , [3 , 2 , 1 ])],
94114 )
95115 @pytest .mark .parametrize ("klass" , [pd .Int64Index , pd .DatetimeIndex ])
96- def test_truncate_decreasing_index (self , before , after , indices , klass ):
116+ def test_truncate_decreasing_index (
117+ self , before , after , indices , klass , frame_or_series
118+ ):
97119 # https://github.com/pandas-dev/pandas/issues/33756
98120 idx = klass ([3 , 2 , 1 , 0 ])
99121 if klass is pd .DatetimeIndex :
100122 before = pd .Timestamp (before ) if before is not None else None
101123 after = pd .Timestamp (after ) if after is not None else None
102124 indices = [pd .Timestamp (i ) for i in indices ]
103- values = pd . DataFrame (range (len (idx )), index = idx )
125+ values = frame_or_series (range (len (idx )), index = idx )
104126 result = values .truncate (before = before , after = after )
105127 expected = values .loc [indices ]
106- tm .assert_frame_equal (result , expected )
128+ tm .assert_equal (result , expected )
107129
108- def test_truncate_multiindex (self ):
130+ def test_truncate_multiindex (self , frame_or_series ):
109131 # GH 34564
110132 mi = pd .MultiIndex .from_product ([[1 , 2 , 3 , 4 ], ["A" , "B" ]], names = ["L1" , "L2" ])
111- s1 = pd .DataFrame (range (mi .shape [0 ]), index = mi , columns = ["col" ])
133+ s1 = DataFrame (range (mi .shape [0 ]), index = mi , columns = ["col" ])
134+ if frame_or_series is Series :
135+ s1 = s1 ["col" ]
136+
112137 result = s1 .truncate (before = 2 , after = 3 )
113138
114- df = pd . DataFrame .from_dict (
139+ df = DataFrame .from_dict (
115140 {"L1" : [2 , 2 , 3 , 3 ], "L2" : ["A" , "B" , "A" , "B" ], "col" : [2 , 3 , 4 , 5 ]}
116141 )
117142 expected = df .set_index (["L1" , "L2" ])
143+ if frame_or_series is Series :
144+ expected = expected ["col" ]
118145
119- tm .assert_frame_equal (result , expected )
146+ tm .assert_equal (result , expected )
0 commit comments