根据这道题总结一下快速排序和堆排序,再根据这两种方法写这道题。
给定整数数组 nums
和整数 k
,请返回数组中第 k
个最大的元素。
请注意,你需要找的是数组排序后的第 k
个最大的元素,而不是第 k
个不同的元素。
你必须设计并实现时间复杂度为 O(n)
的算法解决此问题。
示例 1:
输入: [3,2,1,5,6,4]
, k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6]
, k = 4
输出: 4
提示:
1 <= k <= nums.length <= 105
-104 <= nums[i] <= 104
我们首先给出快速排序的代码,快速排序的思路是先选取一个基准值,然后把小于基准值的放到基准值左边,把大于基准值的放到基准值右边,这样就会变成三部分(基准值左边部分、基准值、基准值右边部分),对基准值左右再递归进行这个步骤。代码分三部分:快速排序辅助分区部分、排序部分和主函数,分区部分就是把比基准值小的放左边,比基准值大的放右边,然后把基准值放中间,排序部分就是递归排序。
#include <iostream>
#include <vector>
#include <utility> // for std::swap// 快速排序的辅助函数,进行分区
int partition(std::vector<int> &nums, int low, int high) {// 选择最左侧的元素作为基准值(pivot)int pivot = nums[low];int i = low + 1; // i指针用来记录比基准值小的区域的最后一个元素的位置int j = high; // j指针用来记录比基准值大的区域的第一个元素的位置// 循环进行分区操作while(true) {// 从左向右找,找到大于等于基准值的元素while (nums[i] < pivot) {i++;}// 从右向左找,找到小于等于基准值的元素while (nums[j] > pivot) {j--;}if (i < j) {std::swap(nums[i], nums[j]);} else {// 完成分区,左边全是小于等于基准值,右边全是大于等于基准值break;}}// 交换基准值到分区的中间std::swap(nums[low], nums[j]);// 返回基准值的最终位置return i;
}// 快速排序的递归函数
void quickSort(std::vector<int> &nums, int low, int high) {if (low < high) {// 分区操作int pivotIndex = partition(nums, low, high);// 对基准值左边的子序列进行快速排序quickSort(nums, low, pivotIndex - 1);// 对基准值右边的子序列进行快速排序quickSort(nums, pivotIndex + 1, high);}
}int main() {std::vector<int> nums = {10, 7, 8, 9, 1, 5};int n = nums.size();quickSort(nums, 0, n - 1);for (int num : nums) {std::cout << num << " ";}return 0;
}
运行结果(每一步分区的过程)为:
6 7 8 9 1 5 3 3 6 1 10
6 1 3 3 1 5 6 9 8 7 10
5 1 3 3 1 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 7 8 9 10
1 1 3 3 5 6 6 7 8 9 10
1 1 3 3 5 6 6 7 8 9 10
快速排序的时间复杂度是O(nlogn)
。
基于快速排序可以写出做这道题的快速选择方法的代码,与快速排序一样,需要先分区,之后确定了基准值最终所在的位置,然后不需要进行排序操作,只需要知道第k大的元素是在基准值左边还是右边,然后在那个分区找就可以了,也是递归来查找,这个就是在快排的过程中直接找到了,所以不需要进行完整的快排,因此复杂度变低:
#include <iostream>
#include <vector>
#include <utility> // for std::swap// 快速排序的辅助函数,进行分区
int partition(std::vector<int> &nums, int low, int high) {// 选择最左侧的元素作为基准值(pivot)int pivot = nums[low];int i = low + 1; // i指针用来记录比基准值小的区域的最后一个元素的位置int j = high; // j指针用来记录比基准值大的区域的第一个元素的位置// 循环进行分区操作while(true) {// 从左向右找,找到大于等于基准值的元素while (nums[i] < pivot) {i++;}// 从右向左找,找到小于等于基准值的元素while (nums[j] > pivot) {j--;}if (i < j) {std::swap(nums[i], nums[j]);} else {// 完成分区,左边全是小于等于基准值,右边全是大于等于基准值break;}}// 交换基准值到分区的中间std::swap(nums[low], nums[j]);// 返回基准值的最终位置return j;
}// 快速排序的递归函数
int quickSelect(std::vector<int> &nums, int low, int high, int kIndex) {if (low == high) {// 当子数组只有一个元素时,返回该元素return nums[low];}int pivotIndex = partition(nums, low, high);if (kIndex <= pivotIndex) {// 第k大的元素索引在左侧子数组中return quickSelect(nums, low, pivotIndex, kIndex);} else {// 第k大的元素索引在右侧子数组中return quickSelect(nums, pivotIndex + 1, high, kIndex);}
}int main() {std::vector<int> nums = {10, 5, 3, 2, 1, 6, 8, 7};int n = nums.size();int k = 3;// 第k大的元素的索引是k-1int kIndex = k - 1;int ans = quickSelect(nums, 0, n - 1, n - 1 - kIndex);std::cout << "The ans is " << ans << std::endl;return 0;
}
注意,当求第k
大的元素时,传入的是索引k-1
,当求第k
小的元素(第n-k+1
大)时,传入索引n-k
(即n-1-kIndex
)。这个方法时间复杂度是O(n)
。
下面来总结一下堆排序和这道题,我们给出堆排序的代码:
#include <iostream>
#include <vector>
#include <algorithm> // for std::swap// 自上向下调整堆,保证堆的性质
void heapify(std::vector<int> &nums, int n, int i) {int largest = i; // 初始时假设当前节点为最大值int left = 2 * i + 1; // 左子节点int right = 2 * i + 2; // 右子节点// 如果左子节点存在且大于当前节点,更新最大值节点if (left < n && nums[left] > nums[largest]) {largest = left;}// 如果右子节点存在且大于当前节点,更新最大值节点if (right < n && nums[right] > nums[largest]) {largest = right;}// 如果最大值节点发生了变化,交换当前节点和最大值节点的值,并继续调整if (largest != i) {std::swap(nums[i], nums[largest]);heapify(nums, n, largest);}
}// 堆排序
void heapSort(std::vector<int> &nums) {int n = nums.size();// 从最后一个非叶子节点开始建堆,即从 (n/2 - 1) 节点开始for (int i = n / 2 - 1; i >= 0; i--) {heapify(nums, n, i);}// 从最后一个元素开始,交换元素并进行调整堆操作for (int i = n - 1; i > 0; i--) {std::swap(nums[0], nums[i]); // 将当前堆的最大值放到数组末尾heapify(nums, i, 0); // 调整堆,新的堆大小为 i}
}int main() {std::vector<int> nums = {16, 10, 8, 7, 2, 3, 4, 1, 9, 14};heapSort(nums);std::cout << "Sorted array: ";for (int num : nums) {std::cout << num << " ";}return 0;
}
堆排序有三个重要部分:维护堆的性质,建堆,排序。以大根堆为例,这是一颗完全二叉树,父节点的值大于子节点的值,下标为i
的节点的父节点下标是(i - 1) / 2
(整数除法),下标为i
的节点的左孩子下标是i * 2 + 1
,右孩子下标是i * 2 + 2
,因此,假如有n
个元素,那么堆的最后一个非叶子节点的下标是n / 2 - 1
。
- 维护堆的性质,即为保证父节点值大于子节点值,从上而下调整,比如当前
i
节点不满足这个性质,那么交换i
节点和它的左右孩子中最大的那个,然后再判断子节点那里是否满足堆的性质(之所以需要这样是因为如果进行了交换,那么子节点那里可能会发生变化,比如3 6 5 2 4
这个情况,首先3
和6
进行了交换,变成了6 3 5 2 4
,那么3 2 4
那个部分(之前是6 2 4
)就需要再次进行交换)。 - 建堆,即从最后一个非叶子节点开始,自下而上维护堆的性质,直到根节点。
- 堆排序,将当前堆的最大值放到数组末尾,然后把它排除出去,再从根向下进行堆的维护,新的堆的大小为
n-1
,重复这个过程,直到只剩一个元素。
运行结果为:
create heap
16 14 8 9 10 3 4 1 7 2
sort heap for 9 nums 14 10 8 9 2 3 4 1 7 16
sort heap for 8 nums 10 9 8 7 2 3 4 1 14 16
sort heap for 7 nums 9 7 8 1 2 3 4 10 14 16
sort heap for 6 nums 8 7 4 1 2 3 9 10 14 16
sort heap for 5 nums 7 3 4 1 2 8 9 10 14 16
sort heap for 4 nums 4 3 2 1 7 8 9 10 14 16
sort heap for 3 nums 3 1 2 4 7 8 9 10 14 16
sort heap for 2 nums 2 1 3 4 7 8 9 10 14 16
sort heap for 1 nums 1 2 3 4 7 8 9 10 14 16
sorted heap
1 2 3 4 7 8 9 10 14 16
可以看到16, 10, 8, 7, 2, 3, 4, 1, 9, 14
经过建堆过程(自最后一个非叶子节点向上维护堆),变成了16 14 8 9 10 3 4 1 7 2
,然后需要进行堆排序,将16
和14
交换,然后不管16
了,这个时候它是最后一个元素,再从根向下维护堆,得到14 10 8 9 2 3 4 1 7 16
,然后再将14
和7
交换,进行相同的步骤,最后排序成功。
有了堆排序的基础,我们利用堆排序解决数组中的第K
个最大元素的问题,事实上,在堆排序取最大值的过程中,已经体现出来了,在第一次取16
,这就是第1
大的元素,第二次取14
就是第2
大的元素,那么我们想得到第k
大元素的值,只需要设置堆排序的停止条件为i > n - k
,然后这时候的nums[0]
(即根节点值)为第k
大的元素。如果我们想得到第’k’小的元素,那么就取第n-k+1
大的元素。
详细代码如下:
#include <iostream>
#include <vector>
#include <algorithm> // for std::swap// 自上向下调整堆,保证堆的性质
void heapify(std::vector<int> &nums, int n, int i) {int largest = i; // 初始时假设当前节点为最大值int left = 2 * i + 1; // 左子节点int right = 2 * i + 2; // 右子节点// 如果左子节点存在且大于当前节点,更新最大值节点if (left < n && nums[left] > nums[largest]) {largest = left;}// 如果右子节点存在且大于当前节点,更新最大值节点if (right < n && nums[right] > nums[largest]) {largest = right;}// 如果最大值节点发生了变化,交换当前节点和最大值节点的值,并继续调整if (largest != i) {std::swap(nums[i], nums[largest]);heapify(nums, n, largest);}
}// 堆排序取数
int heapSelect(std::vector<int> &nums, int k) {int n = nums.size();// 从最后一个非叶子节点开始建堆,即从 (n/2 - 1) 节点开始for (int i = n / 2 - 1; i >= 0; i--) {heapify(nums, n, i);}std::cout << "create heap" << std::endl;for (int num : nums) {std::cout << num << " ";}std::cout << "\n";// 从最后一个元素开始,交换元素并进行调整堆操作for (int i = n - 1; i > n - k; i--) {std::swap(nums[0], nums[i]); // 将当前堆的最大值放到数组末尾heapify(nums, i, 0); // 调整堆,新的堆大小为 istd::cout << "sort heap for " << i << " nums" << " ";for (int num : nums) {std::cout << num << " ";}std::cout << "\n";}std::cout << "sorted heap" << std::endl; for (int num : nums) {std::cout << num << " ";}return nums[0];
}int main() {std::vector<int> nums = {16, 10, 8, 7, 2, 3, 4, 1, 9, 14};int n = nums.size();int k = 4;int ans = heapSelect(nums, n - k + 1);std::cout << "ans=" << ans << std::endl;return 0;
}
运行结果:
create heap
16 14 8 9 10 3 4 1 7 2
sort heap for 9 nums 14 10 8 9 2 3 4 1 7 16
sort heap for 8 nums 10 9 8 7 2 3 4 1 14 16
sort heap for 7 nums 9 7 8 1 2 3 4 10 14 16
sort heap for 6 nums 8 7 4 1 2 3 9 10 14 16
sort heap for 5 nums 7 3 4 1 2 8 9 10 14 16
sort heap for 4 nums 4 3 2 1 7 8 9 10 14 16
sorted heap
4 3 2 1 7 8 9 10 14 16 ans=4
时间复杂度是O(nlogn)
,建堆的复杂度是O(n)
,删除堆顶元素的复杂度是O(klogn)
,所以总共的时间复杂度是O(n+klogn)=O(nlogn)
。