Skip to content

0215 数组中的第K个最大元素

中等 快排

算法思路

寻找第 \(K\) 大(小)数属于是经典算法题了,应该能立即想到快速排序

其主要思想是,将数组 arr[0,n] 分成 [0,p-1] [p+1,n] 两部分,其中左部数组都比 arr[p] 小,右部都比 arr[p] 大。其中两部分数组内部不要求有序。

结合本题,使用快排思想,可以逐步简化问题,在递归中将答案缩小到数组左部或右部,最终获得所求。

代码实现

kth-largest-element-in-an-array.c
void swap(int* a, int* b) {
    int t = *a;
    *a = *b, *b = t;
}

int quickSort(int* nums, int k, int left, int right) {
    int num = nums[(left + right) / 2];
    swap(nums + (left + right) / 2, nums + right);
    int index = -1;
    int sameNum = 1;
    for (int i = left; i < right; i++) {
        if (nums[i] == num) {
            sameNum++;
        }
        if (nums[i] <= num && index < 0) {
            index = i;
        }
        if (nums[i] > num && index >= 0) {
            swap(nums + index, nums + i);
            index++;
        }
    }
    index = (index == -1) ? right : index;
    swap(nums + index, nums + right);
    if (k >= index + 1 && k <= index + sameNum) {
        return num;
    }
    if (k <= index) {
        return quickSort(nums, k, left, index - 1);
    } else {
        return quickSort(nums, k, index + 1, right);
    }
}

int findKthLargest(int* nums, int numsSize, int k) {
    return quickSort(nums, k, 0, numsSize - 1);
}

关于代码:

  • 小优化:考虑到可能存在大量相同元素,由于快排思想会将相同元素集中到一边,这会让我们选取的子数组缩小幅度很低,极大影响算法效率(时间复杂度容易陷入 \(O(n^2)\)),因此引入 sameNum,快速判断相同元素。
  • 易错:对 sameNum 的更新应在 swap 之前,避免重复判断。
  • 易错:注意对 index 的维护,即对于初始化值 -1 的更新。

HACK

尽管通过,依旧能想到对本代码的 hack。

例如,当 nums=[2,2,...,2,1,2]numsSize=10^5k=10^5 时,上述算法将相同元素移至右部(小于等于),当题目要求取最小值时,此代码将退化至接近\(O(n^2)\)复杂度。

不过,对于快排算法,只要运气足够差,总会存在测试样例(或随机化结果)使得其某次运行时间复杂度退化至\(O(n^2)\)。我们能做的,也只是尽量想出一些策略(例如随机选取划分元素等)去避免各种人为设置的非随机数据导致的退化,从而提高鲁棒性。

我们知道快速排序的性能和「划分」出的子数组的长度密切相关。直观地理解如果每次规模为 n 的问题我们都划分成 1 和 n−1,每次递归的时候又向 n−1 的集合中递归,这种情况是最坏的,时间代价是 \(O(n^2)\)

我们可以引入随机化来加速这个过程,它的时间代价的期望是 \(O(n)\),证明过程可以参考「《算法导论》9.2:期望为线性的选择算法」。需要注意的是,这个时间复杂度只有在随机数据下才成立,而对于精心构造的数据则可能表现不佳。

UPDATE

三路切分,解决大量重复元素引起的时间复杂度。

void swap(int* a, int* b) {
    int t = *a;
    *a = *b;
    *b = t;
}

// 寻找第 k 大(即排序后索引为 k-1 的元素)
int quickSelect(int* nums, int left, int right, int k) {
    if (left >= right) return nums[left];

    // 1. 随机选择 pivot 并交换到 left 位置
    int randIdx = left + rand() % (right - left + 1);
    swap(&nums[left], &nums[randIdx]);
    int pivot = nums[left];

    // 2. 三路切分 (3-way partition)
    // lt: less than pivot (实际上我们要降序,所以这里放比pivot大的)
    // gt: greater than pivot (放比pivot小的)
    // i: 当前遍历指针
    // 目标: [left...lt-1] > pivot, [lt...gt] == pivot, [gt+1...right] < pivot

    int lt = left;
    int gt = right;
    int i = left + 1;

    while (i <= gt) {
        if (nums[i] > pivot) {
            swap(&nums[i], &nums[lt]);
            lt++;
            i++;
        } else if (nums[i] < pivot) {
            swap(&nums[i], &nums[gt]);
            gt--;
        } else {
            i++;
        }
    }

    // 此时:
    // [left, lt-1] 是大于 pivot 的
    // [lt, gt] 是等于 pivot 的
    // [gt+1, right] 是小于 pivot 的

    // 3. 判断 k 在哪个区间
    if (k < lt) {
        return quickSelect(nums, left, lt - 1, k);
    } else if (k > gt) {
        return quickSelect(nums, gt + 1, right, k);
    } else {
        return pivot; // k 落在等于 pivot 的区间,直接返回
    }
}

int findKthLargest(int* nums, int numsSize, int k) {
    // 注意:题目通常说第 k 大,对应数组下标是 k-1
    return quickSelect(nums, 0, numsSize - 1, k - 1);
}