@@ -32,13 +32,13 @@ def partition(arr: list[int], low: int, high: int) -> int:
32
32
1
33
33
"""
34
34
pivot = arr [high ]
35
- i = low - 1
36
- for j in range (low , high ):
37
- if arr [j ] >= pivot :
38
- i += 1
39
- arr [i ], arr [j ] = arr [j ], arr [i ]
40
- arr [i + 1 ], arr [high ] = arr [high ], arr [i + 1 ]
41
- return i + 1
35
+ store_index = low - 1
36
+ for i in range (low , high ):
37
+ if arr [i ] >= pivot :
38
+ store_index += 1
39
+ arr [store_index ], arr [i ] = arr [i ], arr [store_index ]
40
+ arr [store_index + 1 ], arr [high ] = arr [high ], arr [store_index + 1 ]
41
+ return store_index + 1
42
42
43
43
44
44
def kth_largest_element (arr : list [int ], position : int ) -> int :
@@ -99,12 +99,11 @@ def kth_largest_element(arr, position):
99
99
raise ValueError ("Invalid value of 'position'" )
100
100
low , high = 0 , len (arr ) - 1
101
101
while low <= high :
102
- if low > len (arr ) - 1 or high < 0 :
103
- return - 1
104
102
pivot_index = partition (arr , low , high )
105
- if pivot_index == position - 1 :
103
+ target_index = position - 1
104
+ if pivot_index == target_index :
106
105
return arr [pivot_index ]
107
- elif pivot_index > position - 1 :
106
+ elif pivot_index > target_index :
108
107
high = pivot_index - 1
109
108
else :
110
109
low = pivot_index + 1
0 commit comments