在看 code 時看到一段神奇的程式碼,也不知道它為何要這樣寫。仔細觀察後發現其實它是想要確認 subset 關係。用 ChatGPT 查一下 subset 怎麼寫後就達到 20 倍加速。

原作者的意圖是確認 A 中元素的 data 是不是都有出現在 B 中,有的話要把這些元素從 B 中刪除。原本做法是用一個迴圈去一一確認 A 中的元素,有找到的話就先暫存在 result 中,並從 B 中移除。途中發現有些 a.data 沒有在 B 中就利用 insert 把剛剛暫存的加回來。

原始程式碼:

void foo(const std::vector<Data> &A, std::vector<int> &B,
         std::vector<std::vector<int>> &results) {
    int flag = 1;
    std::vector<int> result;
    for (const auto &a : A) {
        auto it = std::find(B.begin(), B.end(), a.data);
        if (it == B.end()) {
            flag = 0;
            B.insert(B.end(), result.begin(), result.end());
            break;
        }
        result.push_back(a.data);
        B.erase(std::remove(B.begin(), B.end(), a.data), B.end());
    }
    results.push_back(result);
    std::cout << flag << " ";
}

確認作者意圖後就可以根據原本的用途進行調整,把原本的意圖寫出來。這裡有先使用 transform 將 A 中各元素的 data 提取出來。

根據語義調整:

bool is_subset(const std::vector<int> &A, const std::vector<int> &B) {
    std::unordered_set<int> setB(B.begin(), B.end());
    for (int a : A) {
        if (setB.find(a) == setB.end()) {
            return false;
        }
    }
    return true;
}

void opt_foo(const std::vector<Data> &A, std::vector<int> &B,
             std::vector<std::vector<int>> &results) {
    int flag = 1;
    std::vector<int> A_inner(A.size());
    std::transform(A.begin(), A.end(), A_inner.begin(),
                   [](const auto &a) { return a.data; });
    if (is_subset(A_inner, B)) {
        std::unordered_set<int> tmpSet(A_inner.begin(), A_inner.end());
        B.erase(std::remove_if(B.begin(), B.end(),
                               [&](int x) { return tmpSet.count(x); }),
                B.end());
        results.push_back(A_inner);
    } else {
        flag = 0;
    }
    std::cout << flag << " ";
}
完整程式碼
#include <algorithm>
#include <chrono>
#include <iostream>
#include <random>
#include <unordered_set>
#include <vector>

using namespace std::chrono;

struct Data {
    int data;
    Data(int data) : data(data) {}
};

void foo(const std::vector<Data> &A, std::vector<int> &B,
         std::vector<std::vector<int>> &results) {
    int flag = 1;
    std::vector<int> result;
    for (const auto &a : A) {
        auto it = std::find(B.begin(), B.end(), a.data);
        if (it == B.end()) {
            flag = 0;
            B.insert(B.end(), result.begin(), result.end());
            break;
        }
        result.push_back(a.data);
        B.erase(std::remove(B.begin(), B.end(), a.data), B.end());
    }
    results.push_back(result);
    std::cout << flag << " ";
}

bool is_subset(const std::vector<int> &A, const std::vector<int> &B) {
    std::unordered_set<int> setB(B.begin(), B.end());
    for (int a : A) {
        if (setB.find(a) == setB.end()) {
            return false;
        }
    }
    return true;
}

void opt_foo(const std::vector<Data> &A, std::vector<int> &B,
             std::vector<std::vector<int>> &results) {
    int flag = 1;
    std::vector<int> A_inner(A.size());
    std::transform(A.begin(), A.end(), A_inner.begin(),
                   [](const auto &a) { return a.data; });
    if (is_subset(A_inner, B)) {
        std::unordered_set<int> tmpSet(A_inner.begin(), A_inner.end());
        B.erase(std::remove_if(B.begin(), B.end(),
                               [&](int x) { return tmpSet.count(x); }),
                B.end());
        results.push_back(A_inner);
    } else {
        flag = 0;
    }
    std::cout << flag << " ";
}

int main() {
    // generate random vectors
    std::random_device rd;
    std::mt19937 gen1(10);
    std::mt19937 gen2(10);
    int N = 4000;
    std::uniform_int_distribution<> dis(1, N);

    std::vector<std::vector<int>> results;
    auto start = steady_clock::now();
    for (int t = 0; t < 50; ++t) {
        std::vector<Data> A;
        std::vector<int> B;
        // Fill A and B with random values
        for (int i = 0; i < N; ++i) {
            A.emplace_back(i + 1);
        }
        for (int i = 0; i < 10 * N; ++i) {
            B.push_back(dis(gen1));
        }
        foo(A, B, results);
    }
    auto end = steady_clock::now();
    auto elapsed = duration_cast<milliseconds>(end - start).count();
    std::cout << "Elapsed time: " << elapsed << " ms" << std::endl;

    start = steady_clock::now();
    for (int t = 0; t < 50; ++t) {
        std::vector<Data> A;
        std::vector<int> B;
        // Fill A and B with random values
        for (int i = 0; i < N; ++i) {
            A.emplace_back(i + 1);
        }
        for (int i = 0; i < 10 * N; ++i) {
            B.push_back(dis(gen2));
        }
        opt_foo(A, B, results);
    }
    end = steady_clock::now();
    elapsed = duration_cast<milliseconds>(end - start).count();
    std::cout << "Elapsed time: " << elapsed << " ms" << std::endl;
    return 0;
}
$ g++ -O3 test.cpp -o test.out
$ ./test
1 1 1 1 1 1 0 0 0 0 1 1 1 0 1 1 0 1 1 0 1 0 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 1 
Elapsed time: 1354 ms
1 1 1 1 1 1 0 0 0 0 1 1 1 0 1 1 0 1 1 0 1 0 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 1 
Elapsed time: 65 ms