MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

C++ STL 算法 sort 的性能优化策略

2023-09-207.8k 阅读

1. 理解 C++ STL 的 sort 算法

C++ 标准模板库(STL)中的 sort 算法是一个强大的工具,用于对给定范围内的元素进行排序。sort 函数定义在 <algorithm> 头文件中,其原型如下:

template<class RandomAccessIterator>
void sort (RandomAccessIterator first, RandomAccessIterator last);

template<class RandomAccessIterator, class Compare>
void sort (RandomAccessIterator first, RandomAccessIterator last, Compare comp);

第一个版本使用元素类型的 < 运算符进行比较,第二个版本则允许用户提供自定义的比较函数 compsort 算法通常采用一种混合策略,常见的是结合快速排序(Quicksort)、插入排序(Insertion Sort)和堆排序(Heapsort)。

2. 影响 sort 性能的因素

2.1 数据规模

数据规模是影响 sort 性能的一个关键因素。对于小数据集,插入排序由于其简单性和局部性好的特点,在常数时间开销方面表现出色。而对于大数据集,快速排序的平均时间复杂度为 (O(n log n)),通常能展现出较好的性能。然而,快速排序在最坏情况下(如数据已经有序且采用固定的轴点选择策略)时间复杂度会退化到 (O(n^2)),此时堆排序的 (O(n log n)) 最坏时间复杂度特性就显得尤为重要。

例如,以下代码比较了对不同规模数组进行排序的时间:

#include <iostream>
#include <algorithm>
#include <chrono>
#include <vector>

void test_sort(int size) {
    std::vector<int> vec(size);
    for (int i = 0; i < size; ++i) {
        vec[i] = rand();
    }

    auto start = std::chrono::high_resolution_clock::now();
    std::sort(vec.begin(), vec.end());
    auto end = std::chrono::high_resolution_clock::now();

    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    std::cout << "Sorting " << size << " elements took " << duration << " milliseconds." << std::endl;
}

int main() {
    test_sort(1000);
    test_sort(10000);
    test_sort(100000);
    return 0;
}

2.2 数据分布

数据的分布情况也对 sort 性能有显著影响。如果数据已经部分有序或完全有序,使用默认轴点选择策略的快速排序可能会遇到最坏情况。例如,对于一个已经升序排列的数组,若每次选择第一个元素作为轴点,快速排序会在每次划分时将数组划分为一个很大的子数组和一个空的子数组,导致时间复杂度退化。

为了应对这种情况,可以采用随机化轴点选择策略,即每次从待排序区间中随机选择一个元素作为轴点。下面是一个简单的实现:

#include <iostream>
#include <algorithm>
#include <vector>
#include <cstdlib>
#include <ctime>

template <typename RandomAccessIterator>
void my_sort(RandomAccessIterator first, RandomAccessIterator last) {
    if (first >= last) return;

    auto pivot = first + std::distance(first, last) / 2;
    std::swap(*pivot, *(last - 1));
    auto new_pivot = std::partition(first, last - 1, [&](const auto& a) {
        return a < *(last - 1);
    });
    std::swap(*new_pivot, *(last - 1));

    my_sort(first, new_pivot);
    my_sort(new_pivot + 1, last);
}

int main() {
    std::srand(std::time(nullptr));
    std::vector<int> vec = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5};
    my_sort(vec.begin(), vec.end());
    for (int num : vec) {
        std::cout << num << " ";
    }
    std::cout << std::endl;
    return 0;
}

2.3 比较函数的复杂度

当使用自定义比较函数时,比较函数的复杂度会直接影响 sort 的性能。如果比较函数的时间复杂度较高,如 (O(n)),则整个排序过程的时间复杂度将至少为 (O(n^2)),因为 sort 算法通常需要进行 (O(n log n)) 次比较。

例如,假设我们有一个复杂的自定义比较函数,它需要遍历一个链表来进行比较:

struct ListNode {
    int val;
    ListNode* next;
    ListNode(int x) : val(x), next(nullptr) {}
};

bool complex_compare(ListNode* a, ListNode* b) {
    // 这里假设需要遍历链表来比较
    ListNode* currA = a;
    ListNode* currB = b;
    while (currA && currB) {
        if (currA->val != currB->val) {
            return currA->val < currB->val;
        }
        currA = currA->next;
        currB = currB->next;
    }
    if (!currA && currB) return true;
    return false;
}

// 假设这里有一个函数将数组转换为链表
ListNode* arrayToLinkedList(int arr[], int size) {
    if (size == 0) return nullptr;
    ListNode* head = new ListNode(arr[0]);
    ListNode* curr = head;
    for (int i = 1; i < size; ++i) {
        curr->next = new ListNode(arr[i]);
        curr = curr->next;
    }
    return head;
}

// 假设这里有一个函数将链表转换回数组
void linkedListToArray(ListNode* head, int arr[], int size) {
    ListNode* curr = head;
    for (int i = 0; i < size; ++i) {
        arr[i] = curr->val;
        curr = curr->next;
    }
}

int main() {
    int arr[] = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5};
    int size = sizeof(arr) / sizeof(arr[0]);
    ListNode* head = arrayToLinkedList(arr, size);
    // 这里如果用这样的比较函数进行排序会非常低效
    // 实际中不太可能这样直接用,但为了说明问题
    // 假设我们要对链表指针数组进行排序
    std::vector<ListNode*> listPtrs;
    ListNode* curr = head;
    while (curr) {
        listPtrs.push_back(curr);
        curr = curr->next;
    }
    std::sort(listPtrs.begin(), listPtrs.end(), complex_compare);
    linkedListToArray(listPtrs[0], arr, size);
    for (int i = 0; i < size; ++i) {
        std::cout << arr[i] << " ";
    }
    std::cout << std::endl;
    return 0;
}

3. 性能优化策略

3.1 选择合适的轴点选择策略

如前文所述,随机化轴点选择策略可以有效避免快速排序在某些数据分布下的最坏情况。在 STL 的 sort 实现中,通常会采用三数取中(Median-of-Three)策略。该策略从待排序区间的开头、中间和末尾选取三个元素,然后将中间值作为轴点。这样做可以在一定程度上平衡轴点的选择,减少最坏情况的发生概率。

以下是一个简单的三数取中轴点选择的实现:

template <typename RandomAccessIterator>
typename std::iterator_traits<RandomAccessIterator>::value_type
median_of_three(RandomAccessIterator first, RandomAccessIterator last) {
    auto mid = first + std::distance(first, last) / 2;
    if ((*first < *mid) == (*mid < *last)) {
        return *mid;
    } else if ((*last < *mid) == (*mid < *first)) {
        return *mid;
    } else if ((*first < *last) == (*last < *mid)) {
        return *last;
    } else {
        return *first;
    }
}

template <typename RandomAccessIterator>
void my_sort_with_median(RandomAccessIterator first, RandomAccessIterator last) {
    if (first >= last) return;

    auto pivot_value = median_of_three(first, last);
    auto new_pivot = std::partition(first, last, [&](const auto& a) {
        return a < pivot_value;
    });

    my_sort_with_median(first, new_pivot);
    my_sort_with_median(new_pivot, last);
}

int main() {
    std::vector<int> vec = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5};
    my_sort_with_median(vec.begin(), vec.end());
    for (int num : vec) {
        std::cout << num << " ";
    }
    std::cout << std::endl;
    return 0;
}

3.2 针对小数据集使用插入排序

对于小数据集,插入排序的常数时间开销小且局部性好。在 STL 的 sort 实现中,当子数组的大小小于某个阈值(通常为 16)时,会切换到插入排序。这是因为插入排序在小数据集上的性能优势可以弥补其 (O(n^2)) 的时间复杂度在大数据集上的劣势。

下面是一个结合插入排序的 sort 实现:

template <typename RandomAccessIterator>
void insertion_sort(RandomAccessIterator first, RandomAccessIterator last) {
    for (auto i = first + 1; i < last; ++i) {
        auto key = *i;
        auto j = i;
        while (j > first && *(j - 1) > key) {
            *j = *(j - 1);
            --j;
        }
        *j = key;
    }
}

template <typename RandomAccessIterator>
void my_sort_with_insertion(RandomAccessIterator first, RandomAccessIterator last) {
    const int threshold = 16;
    if (std::distance(first, last) <= threshold) {
        insertion_sort(first, last);
        return;
    }

    auto pivot = first + std::distance(first, last) / 2;
    std::swap(*pivot, *(last - 1));
    auto new_pivot = std::partition(first, last - 1, [&](const auto& a) {
        return a < *(last - 1);
    });
    std::swap(*new_pivot, *(last - 1));

    my_sort_with_insertion(first, new_pivot);
    my_sort_with_insertion(new_pivot + 1, last);
}

int main() {
    std::vector<int> vec = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5};
    my_sort_with_insertion(vec.begin(), vec.end());
    for (int num : vec) {
        std::cout << num << " ";
    }
    std::cout << std::endl;
    return 0;
}

3.3 减少比较函数的开销

如果使用自定义比较函数,要确保其时间复杂度尽可能低。尽量避免在比较函数中进行复杂的计算或操作。例如,如果比较函数需要访问外部资源或进行大量的内存访问,考虑是否可以提前预处理数据,使得比较函数只进行简单的比较操作。

假设我们有一个类 Person,需要按年龄和姓名排序:

#include <iostream>
#include <algorithm>
#include <string>
#include <vector>

struct Person {
    std::string name;
    int age;
    Person(const std::string& n, int a) : name(n), age(a) {}
};

// 原始的复杂比较函数
bool complex_person_compare(const Person& a, const Person& b) {
    if (a.age != b.age) {
        return a.age < b.age;
    } else {
        // 这里字符串比较开销相对较大
        return a.name < b.name;
    }
}

// 优化后的比较函数,提前将姓名转换为哈希值
struct PersonHash {
    int age;
    std::size_t nameHash;
    PersonHash(int a, const std::string& n) : age(a), nameHash(std::hash<std::string>{}(n)) {}
};

bool optimized_person_compare(const PersonHash& a, const PersonHash& b) {
    if (a.age != b.age) {
        return a.age < b.age;
    } else {
        return a.nameHash < b.nameHash;
    }
}

int main() {
    std::vector<Person> people = {{"Alice", 25}, {"Bob", 20}, {"Charlie", 25}};
    std::vector<PersonHash> peopleHash;
    for (const auto& p : people) {
        peopleHash.emplace_back(p.age, p.name);
    }

    std::sort(peopleHash.begin(), peopleHash.end(), optimized_person_compare);

    for (const auto& ph : peopleHash) {
        std::cout << "Age: " << ph.age << ", Name Hash: " << ph.nameHash << std::endl;
    }
    return 0;
}

3.4 利用并行算法

在多核处理器环境下,可以利用并行算法来加速排序过程。C++17 引入了并行版本的 sort,即 std::execution::par 策略。通过使用并行算法,sort 可以将数据分成多个部分,在不同的线程上并行进行排序,然后再合并结果。

以下是使用并行 sort 的示例:

#include <iostream>
#include <algorithm>
#include <execution>
#include <vector>

int main() {
    std::vector<int> vec = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5};
    std::sort(std::execution::par, vec.begin(), vec.end());
    for (int num : vec) {
        std::cout << num << " ";
    }
    std::cout << std::endl;
    return 0;
}

4. 性能测试与分析

为了评估不同优化策略对 sort 性能的影响,可以进行性能测试。我们可以使用 std::chrono 库来测量排序所需的时间,并对不同规模和分布的数据进行测试。

例如,我们可以编写一个性能测试函数,比较原始 sort、使用三数取中轴点选择的 sort 和并行 sort 的性能:

#include <iostream>
#include <algorithm>
#include <execution>
#include <chrono>
#include <vector>
#include <cstdlib>
#include <ctime>

template <typename RandomAccessIterator>
typename std::iterator_traits<RandomAccessIterator>::value_type
median_of_three(RandomAccessIterator first, RandomAccessIterator last) {
    auto mid = first + std::distance(first, last) / 2;
    if ((*first < *mid) == (*mid < *last)) {
        return *mid;
    } else if ((*last < *mid) == (*mid < *first)) {
        return *mid;
    } else if ((*first < *last) == (*last < *mid)) {
        return *last;
    } else {
        return *first;
    }
}

template <typename RandomAccessIterator>
void my_sort_with_median(RandomAccessIterator first, RandomAccessIterator last) {
    if (first >= last) return;

    auto pivot_value = median_of_three(first, last);
    auto new_pivot = std::partition(first, last, [&](const auto& a) {
        return a < pivot_value;
    });

    my_sort_with_median(first, new_pivot);
    my_sort_with_median(new_pivot, last);
}

void performance_test(int size) {
    std::vector<int> vec(size);
    for (int i = 0; i < size; ++i) {
        vec[i] = rand();
    }

    auto start = std::chrono::high_resolution_clock::now();
    std::sort(vec.begin(), vec.end());
    auto end = std::chrono::high_resolution_clock::now();
    auto std_sort_time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

    std::vector<int> vec2 = vec;
    start = std::chrono::high_resolution_clock::now();
    my_sort_with_median(vec2.begin(), vec2.end());
    end = std::chrono::high_resolution_clock::now();
    auto my_sort_time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

    std::vector<int> vec3 = vec;
    start = std::chrono::high_resolution_clock::now();
    std::sort(std::execution::par, vec3.begin(), vec3.end());
    end = std::chrono::high_resolution_clock::now();
    auto par_sort_time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

    std::cout << "Sorting " << size << " elements:" << std::endl;
    std::cout << "std::sort took " << std_sort_time << " milliseconds." << std::endl;
    std::cout << "my_sort_with_median took " << my_sort_time << " milliseconds." << std::endl;
    std::cout << "parallel std::sort took " << par_sort_time << " milliseconds." << std::endl;
}

int main() {
    std::srand(std::time(nullptr));
    performance_test(10000);
    performance_test(100000);
    performance_test(1000000);
    return 0;
}

通过这样的性能测试,可以直观地看到不同优化策略在不同数据规模下的性能表现,从而根据实际需求选择最合适的优化方案。

5. 内存管理与性能

在排序过程中,内存管理也会对性能产生影响。特别是在处理大数据集时,频繁的内存分配和释放可能导致性能瓶颈。

5.1 避免不必要的内存分配

在使用 sort 时,如果数据类型是自定义类,并且该类在构造和析构时涉及内存分配(例如包含动态分配的成员变量),要注意避免在排序过程中产生不必要的临时对象。

例如,假设我们有一个 BigObject 类,包含一个动态分配的数组:

#include <iostream>
#include <algorithm>
#include <vector>
#include <cstdlib>

class BigObject {
private:
    int* data;
    int size;
public:
    BigObject(int s) : size(s) {
        data = new int[size];
        for (int i = 0; i < size; ++i) {
            data[i] = rand();
        }
    }

    ~BigObject() {
        delete[] data;
    }

    BigObject(const BigObject& other) : size(other.size) {
        data = new int[size];
        for (int i = 0; i < size; ++i) {
            data[i] = other.data[i];
        }
    }

    BigObject& operator=(const BigObject& other) {
        if (this != &other) {
            delete[] data;
            size = other.size;
            data = new int[size];
            for (int i = 0; i < size; ++i) {
                data[i] = other.data[i];
            }
        }
        return *this;
    }

    bool operator<(const BigObject& other) const {
        return size < other.size;
    }
};

int main() {
    std::vector<BigObject> objects;
    objects.emplace_back(100);
    objects.emplace_back(200);
    objects.emplace_back(150);

    std::sort(objects.begin(), objects.end());
    return 0;
}

在上述代码中,BigObject 的拷贝构造函数和赋值运算符会进行大量的内存分配和复制操作。为了优化性能,可以使用移动语义。在 C++11 及以后的版本中,可以通过定义移动构造函数和移动赋值运算符来避免不必要的内存复制。

class BigObject {
private:
    int* data;
    int size;
public:
    BigObject(int s) : size(s) {
        data = new int[size];
        for (int i = 0; i < size; ++i) {
            data[i] = rand();
        }
    }

    ~BigObject() {
        delete[] data;
    }

    BigObject(const BigObject& other) : size(other.size) {
        data = new int[size];
        for (int i = 0; i < size; ++i) {
            data[i] = other.data[i];
        }
    }

    BigObject& operator=(const BigObject& other) {
        if (this != &other) {
            delete[] data;
            size = other.size;
            data = new int[size];
            for (int i = 0; i < size; ++i) {
                data[i] = other.data[i];
            }
        }
        return *this;
    }

    BigObject(BigObject&& other) noexcept : size(other.size), data(other.data) {
        other.size = 0;
        other.data = nullptr;
    }

    BigObject& operator=(BigObject&& other) noexcept {
        if (this != &other) {
            delete[] data;
            size = other.size;
            data = other.data;
            other.size = 0;
            other.data = nullptr;
        }
        return *this;
    }

    bool operator<(const BigObject& other) const {
        return size < other.size;
    }
};

5.2 使用合适的容器

选择合适的容器也会影响排序性能。对于需要频繁插入和删除元素的场景,std::list 可能更合适,但它不支持随机访问,无法直接使用 std::sort。而 std::vectorstd::deque 支持随机访问,适合 std::sortstd::vector 具有连续的内存布局,在缓存命中率方面表现较好,尤其适合大数据集的排序。

例如,以下代码比较了对 std::vectorstd::deque 进行排序的性能:

#include <iostream>
#include <algorithm>
#include <vector>
#include <deque>
#include <chrono>

void test_sort_on_vector(int size) {
    std::vector<int> vec(size);
    for (int i = 0; i < size; ++i) {
        vec[i] = rand();
    }

    auto start = std::chrono::high_resolution_clock::now();
    std::sort(vec.begin(), vec.end());
    auto end = std::chrono::high_resolution_clock::now();

    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    std::cout << "Sorting vector of " << size << " elements took " << duration << " milliseconds." << std::endl;
}

void test_sort_on_deque(int size) {
    std::deque<int> deq(size);
    for (int i = 0; i < size; ++i) {
        deq[i] = rand();
    }

    auto start = std::chrono::high_resolution_clock::now();
    std::sort(deq.begin(), deq.end());
    auto end = std::chrono::high_resolution_clock::now();

    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    std::cout << "Sorting deque of " << size << " elements took " << duration << " milliseconds." << std::endl;
}

int main() {
    int size = 100000;
    test_sort_on_vector(size);
    test_sort_on_deque(size);
    return 0;
}

6. 与其他排序算法的对比

虽然 std::sort 是一个强大且高效的排序算法,但在某些特定场景下,其他排序算法可能更具优势。

6.1 稳定排序

std::sort 是不稳定的排序算法,即相等元素的相对顺序在排序后可能改变。如果需要保持相等元素的相对顺序,std::stable_sort 是一个更好的选择。std::stable_sort 的时间复杂度在平均和最坏情况下为 (O(n log^2 n)),而 std::sort 的平均时间复杂度为 (O(n log n))。然而,对于某些数据分布和应用场景,稳定性更为重要。

例如,假设我们有一个结构体 Student,包含姓名和成绩,我们希望按成绩排序,并且成绩相同的学生保持原来的顺序:

#include <iostream>
#include <algorithm>
#include <vector>
#include <string>

struct Student {
    std::string name;
    int score;
    Student(const std::string& n, int s) : name(n), score(s) {}
};

bool compare_students(const Student& a, const Student& b) {
    return a.score < b.score;
}

int main() {
    std::vector<Student> students = {{"Alice", 85}, {"Bob", 90}, {"Charlie", 85}, {"David", 95}};
    std::stable_sort(students.begin(), students.end(), compare_students);
    for (const auto& s : students) {
        std::cout << s.name << ": " << s.score << std::endl;
    }
    return 0;
}

6.2 基数排序

基数排序(Radix Sort)是一种非比较排序算法,它根据数字的每一位来进行排序。基数排序在处理整数类型数据且数据范围有限时,具有线性的时间复杂度 (O(n))。与 std::sort 的 (O(n log n)) 平均时间复杂度相比,在特定场景下基数排序可能更快。

以下是一个简单的基数排序实现:

#include <iostream>
#include <vector>

void counting_sort(std::vector<int>& arr, int exp) {
    const int n = arr.size();
    std::vector<int> output(n);
    std::vector<int> count(10, 0);

    for (int i = 0; i < n; ++i) {
        count[(arr[i] / exp) % 10]++;
    }

    for (int i = 1; i < 10; ++i) {
        count[i] += count[i - 1];
    }

    for (int i = n - 1; i >= 0; --i) {
        output[count[(arr[i] / exp) % 10] - 1] = arr[i];
        count[(arr[i] / exp) % 10]--;
    }

    for (int i = 0; i < n; ++i) {
        arr[i] = output[i];
    }
}

void radix_sort(std::vector<int>& arr) {
    int max_val = *std::max_element(arr.begin(), arr.end());
    for (int exp = 1; max_val / exp > 0; exp *= 10) {
        counting_sort(arr, exp);
    }
}

int main() {
    std::vector<int> arr = {170, 45, 75, 90, 802, 24, 2, 66};
    radix_sort(arr);
    for (int num : arr) {
        std::cout << num << " ";
    }
    std::cout << std::endl;
    return 0;
}

通过了解不同排序算法的特点和适用场景,可以在实际编程中根据需求选择最合适的排序方法,以达到最优的性能。

在实际应用中,需要综合考虑数据规模、数据分布、比较函数复杂度、内存管理等多方面因素,对 sort 算法进行优化,以实现高效的排序操作。同时,了解其他排序算法与 std::sort 的对比,也有助于在不同场景下做出更明智的选择。