在看 code 時看到一段神奇的程式碼,也不知道它為何要這樣寫。仔細觀察後發現其實它是想要確認 subset 關係。用 ChatGPT 查一下 subset 怎麼寫後就達到 20 倍加速。
原作者的意圖是確認 A 中元素的 data 是不是都有出現在 B 中,有的話要把這些元素從 B 中刪除。原本做法是用一個迴圈去一一確認 A 中的元素,有找到的話就先暫存在 result 中,並從 B 中移除。途中發現有些 a.data 沒有在 B 中就利用 insert 把剛剛暫存的加回來。
原始程式碼:
void foo(const std::vector<Data> &A, std::vector<int> &B,
std::vector<std::vector<int>> &results) {
int flag = 1;
std::vector<int> result;
for (const auto &a : A) {
auto it = std::find(B.begin(), B.end(), a.data);
if (it == B.end()) {
flag = 0;
B.insert(B.end(), result.begin(), result.end());
break;
}
result.push_back(a.data);
B.erase(std::remove(B.begin(), B.end(), a.data), B.end());
}
results.push_back(result);
std::cout << flag << " ";
}
確認作者意圖後就可以根據原本的用途進行調整,把原本的意圖寫出來。這裡有先使用 transform 將 A 中各元素的 data 提取出來。
根據語義調整:
bool is_subset(const std::vector<int> &A, const std::vector<int> &B) {
std::unordered_set<int> setB(B.begin(), B.end());
for (int a : A) {
if (setB.find(a) == setB.end()) {
return false;
}
}
return true;
}
void opt_foo(const std::vector<Data> &A, std::vector<int> &B,
std::vector<std::vector<int>> &results) {
int flag = 1;
std::vector<int> A_inner(A.size());
std::transform(A.begin(), A.end(), A_inner.begin(),
[](const auto &a) { return a.data; });
if (is_subset(A_inner, B)) {
std::unordered_set<int> tmpSet(A_inner.begin(), A_inner.end());
B.erase(std::remove_if(B.begin(), B.end(),
[&](int x) { return tmpSet.count(x); }),
B.end());
results.push_back(A_inner);
} else {
flag = 0;
}
std::cout << flag << " ";
}
完整程式碼
#include <algorithm>
#include <chrono>
#include <iostream>
#include <random>
#include <unordered_set>
#include <vector>
using namespace std::chrono;
struct Data {
int data;
Data(int data) : data(data) {}
};
void foo(const std::vector<Data> &A, std::vector<int> &B,
std::vector<std::vector<int>> &results) {
int flag = 1;
std::vector<int> result;
for (const auto &a : A) {
auto it = std::find(B.begin(), B.end(), a.data);
if (it == B.end()) {
flag = 0;
B.insert(B.end(), result.begin(), result.end());
break;
}
result.push_back(a.data);
B.erase(std::remove(B.begin(), B.end(), a.data), B.end());
}
results.push_back(result);
std::cout << flag << " ";
}
bool is_subset(const std::vector<int> &A, const std::vector<int> &B) {
std::unordered_set<int> setB(B.begin(), B.end());
for (int a : A) {
if (setB.find(a) == setB.end()) {
return false;
}
}
return true;
}
void opt_foo(const std::vector<Data> &A, std::vector<int> &B,
std::vector<std::vector<int>> &results) {
int flag = 1;
std::vector<int> A_inner(A.size());
std::transform(A.begin(), A.end(), A_inner.begin(),
[](const auto &a) { return a.data; });
if (is_subset(A_inner, B)) {
std::unordered_set<int> tmpSet(A_inner.begin(), A_inner.end());
B.erase(std::remove_if(B.begin(), B.end(),
[&](int x) { return tmpSet.count(x); }),
B.end());
results.push_back(A_inner);
} else {
flag = 0;
}
std::cout << flag << " ";
}
int main() {
// generate random vectors
std::random_device rd;
std::mt19937 gen1(10);
std::mt19937 gen2(10);
int N = 4000;
std::uniform_int_distribution<> dis(1, N);
std::vector<std::vector<int>> results;
auto start = steady_clock::now();
for (int t = 0; t < 50; ++t) {
std::vector<Data> A;
std::vector<int> B;
// Fill A and B with random values
for (int i = 0; i < N; ++i) {
A.emplace_back(i + 1);
}
for (int i = 0; i < 10 * N; ++i) {
B.push_back(dis(gen1));
}
foo(A, B, results);
}
auto end = steady_clock::now();
auto elapsed = duration_cast<milliseconds>(end - start).count();
std::cout << "Elapsed time: " << elapsed << " ms" << std::endl;
start = steady_clock::now();
for (int t = 0; t < 50; ++t) {
std::vector<Data> A;
std::vector<int> B;
// Fill A and B with random values
for (int i = 0; i < N; ++i) {
A.emplace_back(i + 1);
}
for (int i = 0; i < 10 * N; ++i) {
B.push_back(dis(gen2));
}
opt_foo(A, B, results);
}
end = steady_clock::now();
elapsed = duration_cast<milliseconds>(end - start).count();
std::cout << "Elapsed time: " << elapsed << " ms" << std::endl;
return 0;
}
$ g++ -O3 test.cpp -o test.out
$ ./test
1 1 1 1 1 1 0 0 0 0 1 1 1 0 1 1 0 1 1 0 1 0 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 1
Elapsed time: 1354 ms
1 1 1 1 1 1 0 0 0 0 1 1 1 0 1 1 0 1 1 0 1 0 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 1
Elapsed time: 65 ms