0215 数组中的第K个最大元素
中等 快排
算法思路
寻找第 \(K\) 大(小)数属于是经典算法题了,应该能立即想到快速排序。
其主要思想是,将数组 arr[0,n] 分成 [0,p-1] [p+1,n] 两部分,其中左部数组都比 arr[p] 小,右部都比 arr[p] 大。其中两部分数组内部不要求有序。
结合本题,使用快排思想,可以逐步简化问题,在递归中将答案缩小到数组左部或右部,最终获得所求。
代码实现
kth-largest-element-in-an-array.c
void swap(int* a, int* b) {
int t = *a;
*a = *b, *b = t;
}
int quickSort(int* nums, int k, int left, int right) {
int num = nums[(left + right) / 2];
swap(nums + (left + right) / 2, nums + right);
int index = -1;
int sameNum = 1;
for (int i = left; i < right; i++) {
if (nums[i] == num) {
sameNum++;
}
if (nums[i] <= num && index < 0) {
index = i;
}
if (nums[i] > num && index >= 0) {
swap(nums + index, nums + i);
index++;
}
}
index = (index == -1) ? right : index;
swap(nums + index, nums + right);
if (k >= index + 1 && k <= index + sameNum) {
return num;
}
if (k <= index) {
return quickSort(nums, k, left, index - 1);
} else {
return quickSort(nums, k, index + 1, right);
}
}
int findKthLargest(int* nums, int numsSize, int k) {
return quickSort(nums, k, 0, numsSize - 1);
}
关于代码:
- 小优化:考虑到可能存在大量相同元素,由于快排思想会将相同元素集中到一边,这会让我们选取的子数组缩小幅度很低,极大影响算法效率(时间复杂度容易陷入 \(O(n^2)\)),因此引入
sameNum,快速判断相同元素。 - 易错:对
sameNum的更新应在swap之前,避免重复判断。 - 易错:注意对
index的维护,即对于初始化值-1的更新。
HACK
尽管通过,依旧能想到对本代码的 hack。
例如,当 nums=[2,2,...,2,1,2],numsSize=10^5,k=10^5 时,上述算法将相同元素移至右部(小于等于),当题目要求取最小值时,此代码将退化至接近\(O(n^2)\)复杂度。
不过,对于快排算法,只要运气足够差,总会存在测试样例(或随机化结果)使得其某次运行时间复杂度退化至\(O(n^2)\)。我们能做的,也只是尽量想出一些策略(例如随机选取划分元素等)去避免各种人为设置的非随机数据导致的退化,从而提高鲁棒性。
我们知道快速排序的性能和「划分」出的子数组的长度密切相关。直观地理解如果每次规模为 n 的问题我们都划分成 1 和 n−1,每次递归的时候又向 n−1 的集合中递归,这种情况是最坏的,时间代价是 \(O(n^2)\)。
我们可以引入随机化来加速这个过程,它的时间代价的期望是 \(O(n)\),证明过程可以参考「《算法导论》9.2:期望为线性的选择算法」。需要注意的是,这个时间复杂度只有在随机数据下才成立,而对于精心构造的数据则可能表现不佳。
UPDATE
三路切分,解决大量重复元素引起的时间复杂度。
void swap(int* a, int* b) {
int t = *a;
*a = *b;
*b = t;
}
// 寻找第 k 大(即排序后索引为 k-1 的元素)
int quickSelect(int* nums, int left, int right, int k) {
if (left >= right) return nums[left];
// 1. 随机选择 pivot 并交换到 left 位置
int randIdx = left + rand() % (right - left + 1);
swap(&nums[left], &nums[randIdx]);
int pivot = nums[left];
// 2. 三路切分 (3-way partition)
// lt: less than pivot (实际上我们要降序,所以这里放比pivot大的)
// gt: greater than pivot (放比pivot小的)
// i: 当前遍历指针
// 目标: [left...lt-1] > pivot, [lt...gt] == pivot, [gt+1...right] < pivot
int lt = left;
int gt = right;
int i = left + 1;
while (i <= gt) {
if (nums[i] > pivot) {
swap(&nums[i], &nums[lt]);
lt++;
i++;
} else if (nums[i] < pivot) {
swap(&nums[i], &nums[gt]);
gt--;
} else {
i++;
}
}
// 此时:
// [left, lt-1] 是大于 pivot 的
// [lt, gt] 是等于 pivot 的
// [gt+1, right] 是小于 pivot 的
// 3. 判断 k 在哪个区间
if (k < lt) {
return quickSelect(nums, left, lt - 1, k);
} else if (k > gt) {
return quickSelect(nums, gt + 1, right, k);
} else {
return pivot; // k 落在等于 pivot 的区间,直接返回
}
}
int findKthLargest(int* nums, int numsSize, int k) {
// 注意:题目通常说第 k 大,对应数组下标是 k-1
return quickSelect(nums, 0, numsSize - 1, k - 1);
}