This article has an English version available.

快速排序的细节

快速排序的平均时间复杂度为 O(n log n)。然而在最坏情况下,快速排序会达到 O(n^2)。当选取的枢轴(pivot)是数组中的最小或最大元素时,就会出现最坏情况。此时分区会得到一个包含 n-1 个元素的子数组和一个包含 0 个元素的子数组,从而产生 N 层递归,时间复杂度为 O(n^2)。

可能导致最坏情况的两种场景:

  1. 全部元素相同的数组:此时枢轴始终与数组中所有元素相同,分区每次只能“移除”一个元素,递归树高度为 N,从而形成 O(N^2) 的最坏复杂度。
  2. 已经排好序的数组:此时枢轴要么是最小元素,要么是最大元素(取决于实现),分区得到一个大小为 0 的子数组和一个大小为 N-1 的子数组,同样产生 N 层递归和 O(N^2) 的最坏复杂度。

不过,下面这个经过精心设计的算法可以避免这两种最坏情况。它同样是原地算法,意味着只需要很少的额外内存。步骤如下:

  1. 选择一个枢轴元素,这里采用随机元素。
  2. 交换枢轴元素与数组第一个元素,此时第一个元素即为枢轴。
  3. 定义两个指针,一个指向数组第二个元素,另一个指向数组最后一个元素。
  4. 重复以下步骤,直到 left_pointer 大于等于 right_pointer
    1. left_pointer 向右移动,直到它指向一个大于等于枢轴的元素。
    2. right_pointer 向左移动,直到它指向一个小于等于枢轴的元素。
    3. 如果 left_pointer 小于等于 right_pointer,则交换两者指向的元素,并把两个指针同时向中间移动一步。
  5. 当两个指针交错后,将第一个枢轴元素与 right_pointer 所指元素交换。
  6. 对枢轴左右两侧的子数组递归应用同样的步骤。

全相等元素的数组

在所有元素相同的场景下,枢轴等于其他所有元素,所以在步骤 4.a 与 4.b 中,指针不会移动。于是 4.c 只会交换两个元素并让指针向中间靠拢,数组会被有效地分成两个规模相近的子数组。

因此,该算法能避免因为元素相等而导致的最坏情况,且会以 O(NlogN) 的时间复杂度完成排序,也就是其平均复杂度。

已排序数组

为避免该场景,应随机选择枢轴,而不是总选第一个或最后一个元素。随机选择枢轴能降低选到最小/最大元素的概率,从而更可能得到平衡的分区,使整体运行时间接近 O(NlogN) 的平均复杂度。

另一个细节

在步骤 4 中,while 循环条件包含 left_pointerright_pointer 指向同一元素的情况。这个元素可能大于枢轴,循环会执行一次后退出,从而保证 right_pointer 指向的元素小于等于枢轴,这样它就可以和枢轴交换。

代码

cpp
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <random>
using LL = long long;
using namespace std;

void get_piv(LL arr[], LL l, LL r){
    LL piv_id = rand() % (r - l) + l;
    swap(arr[l], arr[piv_id]);
}
void sort(LL arr[], LL l, LL r){
    if (r - l <= 1)
        return;
    get_piv(arr, l, r);
    LL piv = arr[l];
    LL tail = r-1, head = l+1;
    while (head <= tail){
        if (arr[head] < piv) ++head;
        else if (arr[tail] > piv) --tail;
        else{
            swap(arr[head], arr[tail]);
            head++, tail--;
        }
    }
    swap(arr[tail], arr[l]);
    sort(arr, l, tail);
    sort(arr, tail+1, r);
}
LL arr[200000];
int main (){
    srandom(time(NULL));
    LL n;
    scanf("%lld", &n);
    for (LL i = 0; i < n; ++i)
        scanf("%lld", &arr[i]);
    sort(arr, 0, n);
    for (LL i = 0; i < n; ++i)
        printf("%lld ", arr[i]);

}