Grok 10.0.0
algo-inl.h
Go to the documentation of this file.
1// Copyright 2021 Google LLC
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Normal include guard for target-independent parts
17#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
18#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
19
20#include <stdint.h>
21#include <string.h> // memcpy
22
23#include <algorithm>
24#include <cmath> // std::abs
25#include <vector>
26
27#include "hwy/base.h"
29
30// Third-party algorithms
31#define HAVE_AVX2SORT 0
32#define HAVE_IPS4O 0
33// When enabling, consider changing max_threads (required for Table 1a)
34#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1)
35#define HAVE_PDQSORT 0
36#define HAVE_SORT512 0
37#define HAVE_VXSORT 0
38
39#if HAVE_AVX2SORT
40HWY_PUSH_ATTRIBUTES("avx2,avx")
41#include "avx2sort.h"
43#endif
44#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O
45#include "third_party/ips4o/include/ips4o.hpp"
46#include "third_party/ips4o/include/ips4o/thread_pool.hpp"
47#endif
48#if HAVE_PDQSORT
49#include "third_party/boost/allowed/sort/sort.hpp"
50#endif
51#if HAVE_SORT512
52#include "sort512.h"
53#endif
54
55// vxsort is difficult to compile for multiple targets because it also uses
56// .cpp files, and we'd also have to #undef its include guards. Instead, compile
57// only for AVX2 or AVX3 depending on this macro.
58#define VXSORT_AVX3 1
59#if HAVE_VXSORT
60// inlined from vxsort_targets_enable_avx512 (must close before end of header)
61#ifdef __GNUC__
62#ifdef __clang__
63#if VXSORT_AVX3
64#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \
65 apply_to = any(function))
66#else
67#pragma clang attribute push(__attribute__((target("avx2"))), \
68 apply_to = any(function))
69#endif // VXSORT_AVX3
70
71#else
72#pragma GCC push_options
73#if VXSORT_AVX3
74#pragma GCC target("avx512f,avx512dq")
75#else
76#pragma GCC target("avx2")
77#endif // VXSORT_AVX3
78#endif
79#endif
80
81#if VXSORT_AVX3
82#include "vxsort/machine_traits.avx512.h"
83#else
84#include "vxsort/machine_traits.avx2.h"
85#endif // VXSORT_AVX3
86#include "vxsort/vxsort.h"
87#ifdef __GNUC__
88#ifdef __clang__
89#pragma clang attribute pop
90#else
91#pragma GCC pop_options
92#endif
93#endif
94#endif // HAVE_VXSORT
95
96namespace hwy {
97
99
100static inline std::vector<Dist> AllDist() {
101 return {/*Dist::kUniform8, Dist::kUniform16,*/ Dist::kUniform32};
102}
103
104static inline const char* DistName(Dist dist) {
105 switch (dist) {
106 case Dist::kUniform8:
107 return "uniform8";
108 case Dist::kUniform16:
109 return "uniform16";
110 case Dist::kUniform32:
111 return "uniform32";
112 }
113 return "unreachable";
114}
115
116template <typename T>
118 public:
119 void Notify(T value) {
120 min_ = std::min(min_, value);
121 max_ = std::max(max_, value);
122 // Converting to integer would truncate floats, multiplying to save digits
123 // risks overflow especially when casting, so instead take the sum of the
124 // bit representations as the checksum.
125 uint64_t bits = 0;
126 static_assert(sizeof(T) <= 8, "Expected a built-in type");
127 CopyBytes<sizeof(T)>(&value, &bits);
128 sum_ += bits;
129 count_ += 1;
130 }
131
132 bool operator==(const InputStats& other) const {
133 if (count_ != other.count_) {
134 HWY_ABORT("count %d vs %d\n", static_cast<int>(count_),
135 static_cast<int>(other.count_));
136 }
137
138 if (min_ != other.min_ || max_ != other.max_) {
139 HWY_ABORT("minmax %f/%f vs %f/%f\n", static_cast<double>(min_),
140 static_cast<double>(max_), static_cast<double>(other.min_),
141 static_cast<double>(other.max_));
142 }
143
144 // Sum helps detect duplicated/lost values
145 if (sum_ != other.sum_) {
146 HWY_ABORT("Sum mismatch %g %g; min %g max %g\n",
147 static_cast<double>(sum_), static_cast<double>(other.sum_),
148 static_cast<double>(min_), static_cast<double>(max_));
149 }
150
151 return true;
152 }
153
154 private:
155 T min_ = hwy::HighestValue<T>();
156 T max_ = hwy::LowestValue<T>();
157 uint64_t sum_ = 0;
158 size_t count_ = 0;
159};
160
// Identifies a sorting implementation. Third-party entries only exist when
// the corresponding HAVE_* macro is nonzero, so the numeric values of the
// enumerators depend on the build configuration; refer to them by name only.
enum class Algo {
#if HAVE_AVX2SORT
  kSEA,
#endif

#if HAVE_IPS4O
  kIPS4O,
#endif

#if HAVE_PARALLEL_IPS4O
  kParallelIPS4O,
#endif

#if HAVE_PDQSORT
  kPDQ,
#endif

#if HAVE_SORT512
  kSort512,
#endif

#if HAVE_VXSORT
  kVXSort,
#endif

  // Always available:
  kStd,
  kVQSort,
  kHeap,
};
184
185const char* AlgoName(Algo algo) {
186 switch (algo) {
187#if HAVE_AVX2SORT
188 case Algo::kSEA:
189 return "sea";
190#endif
191#if HAVE_IPS4O
192 case Algo::kIPS4O:
193 return "ips4o";
194#endif
195#if HAVE_PARALLEL_IPS4O
196 case Algo::kParallelIPS4O:
197 return "par_ips4o";
198#endif
199#if HAVE_PDQSORT
200 case Algo::kPDQ:
201 return "pdq";
202#endif
203#if HAVE_SORT512
204 case Algo::kSort512:
205 return "sort512";
206#endif
207#if HAVE_VXSORT
208 case Algo::kVXSort:
209 return "vxsort";
210#endif
211 case Algo::kStd:
212 return "std";
213 case Algo::kVQSort:
214 return "vq";
215 case Algo::kHeap:
216 return "heap";
217 }
218 return "unreachable";
219}
220
221} // namespace hwy
222#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
223
224// Per-target
225#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == \
226 defined(HWY_TARGET_TOGGLE)
227#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
228#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
229#else
230#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
231#endif
232
235#include "hwy/contrib/sort/vqsort-inl.h" // HeapSort
237
239namespace hwy {
240namespace HWY_NAMESPACE {
241
243 static HWY_INLINE uint64_t SplitMix64(uint64_t z) {
244 z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
245 z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
246 return z ^ (z >> 31);
247 }
248
249 public:
250 // Generates two vectors of 64-bit seeds via SplitMix64 and stores into
251 // `seeds`. Generating these afresh in each ChoosePivot is too expensive.
252 template <class DU64>
253 static void GenerateSeeds(DU64 du64, TFromD<DU64>* HWY_RESTRICT seeds) {
254 seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull);
255 for (size_t i = 1; i < 2 * Lanes(du64); ++i) {
256 seeds[i] = SplitMix64(seeds[i - 1]);
257 }
258 }
259
260 // Need to pass in the state because vector cannot be class members.
261 template <class DU64>
262 static Vec<DU64> RandomBits(DU64 /* tag */, Vec<DU64>& state0,
263 Vec<DU64>& state1) {
264 Vec<DU64> s1 = state0;
265 Vec<DU64> s0 = state1;
266 const Vec<DU64> bits = Add(s1, s0);
267 state0 = s0;
268 s1 = Xor(s1, ShiftLeft<23>(s1));
269 state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0))));
270 return bits;
271 }
272};
273
274template <typename T, class DU64, HWY_IF_NOT_FLOAT(T)>
276 const Vec<DU64> mask) {
277 const Vec<DU64> bits = Xorshift128Plus::RandomBits(du64, s0, s1);
278 return And(bits, mask);
279}
280
281// Important to avoid denormals, which are flushed to zero by SIMD but not
282// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD.
283template <typename T, class DU64, HWY_IF_FLOAT(T)>
284Vec<DU64> RandomValues(DU64 du64, Vec<DU64>& s0, Vec<DU64>& s1,
285 const Vec<DU64> mask) {
286 const Vec<DU64> bits = Xorshift128Plus::RandomBits(du64, s0, s1);
287 const Vec<DU64> values = And(bits, mask);
288#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to i32
289 const RebindToSigned<DU64> di;
290#else
291 const Repartition<MakeSigned<T>, DU64> di;
292#endif
293 const RebindToFloat<decltype(di)> df;
294 const RebindToUnsigned<decltype(di)> du;
295 const auto k1 = BitCast(du64, Set(df, T{1.0}));
296 const auto mantissa = BitCast(du64, Set(du, MantissaMask<T>()));
297 // Avoid NaN/denormal by converting from (range-limited) integer.
298 const Vec<DU64> no_nan = OrAnd(k1, values, mantissa);
299 return BitCast(du64, ConvertTo(df, BitCast(di, no_nan)));
300}
301
302template <class DU64>
303Vec<DU64> MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) {
304 switch (sizeof_t) {
305 case 2:
306 return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull
307 : 0xFFFFFFFFFFFFFFFFull);
308 case 4:
309 return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull
310 : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull
311 : 0xFFFFFFFFFFFFFFFFull);
312 case 8:
313 return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull
314 : (dist == Dist::kUniform16) ? 0x000000000000FFFFull
315 : 0x00000000FFFFFFFFull);
316 default:
317 HWY_ABORT("Logic error");
318 return Zero(du64);
319 }
320}
321
322template <typename T>
323InputStats<T> GenerateInput(const Dist dist, T* v, size_t num) {
325 using VU64 = Vec<decltype(du64)>;
326 const size_t N64 = Lanes(du64);
327 auto buf = hwy::AllocateAligned<uint64_t>(2 * N64);
328 Xorshift128Plus::GenerateSeeds(du64, buf.get());
329 auto s0 = Load(du64, buf.get());
330 auto s1 = Load(du64, buf.get() + N64);
331
332 const VU64 mask = MaskForDist(du64, dist, sizeof(T));
333
334 const Repartition<T, decltype(du64)> d;
335 const size_t N = Lanes(d);
336 size_t i = 0;
337 for (; i + N <= num; i += N) {
338 const VU64 bits = RandomValues<T>(du64, s0, s1, mask);
339#if HWY_ARCH_RVV || (HWY_TARGET == HWY_NEON && HWY_ARCH_ARM_V7)
340 // v may not be 64-bit aligned
341 StoreU(bits, du64, buf.get());
342 memcpy(v + i, buf.get(), N64 * sizeof(uint64_t));
343#else
344 StoreU(bits, du64, reinterpret_cast<uint64_t*>(v + i));
345#endif
346 }
347 if (i < num) {
348 const VU64 bits = RandomValues<T>(du64, s0, s1, mask);
349 StoreU(bits, du64, buf.get());
350 memcpy(v + i, buf.get(), (num - i) * sizeof(T));
351 }
352
353 InputStats<T> input_stats;
354 for (size_t i = 0; i < num; ++i) {
355 input_stats.Notify(v[i]);
356 }
357 return input_stats;
358}
359
362};
363
365#if HAVE_PARALLEL_IPS4O
366 const unsigned max_threads = hwy::LimitsMax<unsigned>(); // 16 for Table 1a
367 ips4o::StdThreadPool pool{static_cast<int>(
368 HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))};
369#endif
370 std::vector<ThreadLocal> tls{1};
371};
372
373// Bridge from keys (passed to Run) to lanes as expected by HeapSort. For
374// non-128-bit keys they are the same:
375template <class Order, typename KeyType, HWY_IF_NOT_LANE_SIZE(KeyType, 16)>
376void CallHeapSort(KeyType* HWY_RESTRICT keys, const size_t num_keys) {
377 using detail::TraitsLane;
379 if (Order().IsAscending()) {
380 const SharedTraits<TraitsLane<detail::OrderAscending<KeyType>>> st;
381 return detail::HeapSort(st, keys, num_keys);
382 } else {
383 const SharedTraits<TraitsLane<detail::OrderDescending<KeyType>>> st;
384 return detail::HeapSort(st, keys, num_keys);
385 }
386}
387
388#if VQSORT_ENABLED
389template <class Order>
390void CallHeapSort(hwy::uint128_t* HWY_RESTRICT keys, const size_t num_keys) {
391 using detail::SharedTraits;
392 using detail::Traits128;
393 uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
394 const size_t num_lanes = num_keys * 2;
395 if (Order().IsAscending()) {
396 const SharedTraits<Traits128<detail::OrderAscending128>> st;
397 return detail::HeapSort(st, lanes, num_lanes);
398 } else {
399 const SharedTraits<Traits128<detail::OrderDescending128>> st;
400 return detail::HeapSort(st, lanes, num_lanes);
401 }
402}
403
404template <class Order>
405void CallHeapSort(K64V64* HWY_RESTRICT keys, const size_t num_keys) {
406 using detail::SharedTraits;
407 using detail::Traits128;
408 uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
409 const size_t num_lanes = num_keys * 2;
410 if (Order().IsAscending()) {
411 const SharedTraits<Traits128<detail::OrderAscendingKV128>> st;
412 return detail::HeapSort(st, lanes, num_lanes);
413 } else {
414 const SharedTraits<Traits128<detail::OrderDescendingKV128>> st;
415 return detail::HeapSort(st, lanes, num_lanes);
416 }
417}
418#endif // VQSORT_ENABLED
419
420template <class Order, typename KeyType>
421void Run(Algo algo, KeyType* HWY_RESTRICT inout, size_t num,
422 SharedState& shared, size_t thread) {
423 const std::less<KeyType> less;
424 const std::greater<KeyType> greater;
425
426 switch (algo) {
427#if HAVE_AVX2SORT
428 case Algo::kSEA:
429 return avx2::quicksort(inout, static_cast<int>(num));
430#endif
431
432#if HAVE_IPS4O
433 case Algo::kIPS4O:
434 if (Order().IsAscending()) {
435 return ips4o::sort(inout, inout + num, less);
436 } else {
437 return ips4o::sort(inout, inout + num, greater);
438 }
439#endif
440
441#if HAVE_PARALLEL_IPS4O
442 case Algo::kParallelIPS4O:
443 if (Order().IsAscending()) {
444 return ips4o::parallel::sort(inout, inout + num, less, shared.pool);
445 } else {
446 return ips4o::parallel::sort(inout, inout + num, greater, shared.pool);
447 }
448#endif
449
450#if HAVE_SORT512
451 case Algo::kSort512:
452 HWY_ABORT("not supported");
453 // return Sort512::Sort(inout, num);
454#endif
455
456#if HAVE_PDQSORT
457 case Algo::kPDQ:
458 if (Order().IsAscending()) {
459 return boost::sort::pdqsort_branchless(inout, inout + num, less);
460 } else {
461 return boost::sort::pdqsort_branchless(inout, inout + num, greater);
462 }
463#endif
464
465#if HAVE_VXSORT
466 case Algo::kVXSort: {
467#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \
468 (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2)
469 fprintf(stderr, "Do not call for target %s\n",
471 return;
472#else
473#if VXSORT_AVX3
474 vxsort::vxsort<KeyType, vxsort::AVX512> vx;
475#else
476 vxsort::vxsort<KeyType, vxsort::AVX2> vx;
477#endif
478 if (Order().IsAscending()) {
479 return vx.sort(inout, inout + num - 1);
480 } else {
481 fprintf(stderr, "Skipping VX - does not support descending order\n");
482 return;
483 }
484#endif // enabled for this target
485 }
486#endif // HAVE_VXSORT
487
488 case Algo::kStd:
489 if (Order().IsAscending()) {
490 return std::sort(inout, inout + num, less);
491 } else {
492 return std::sort(inout, inout + num, greater);
493 }
494
495 case Algo::kVQSort:
496 return shared.tls[thread].sorter(inout, num, Order());
497
498 case Algo::kHeap:
499 return CallHeapSort<Order>(inout, num);
500
501 default:
502 HWY_ABORT("Not implemented");
503 }
504}
505
506// NOLINTNEXTLINE(google-readability-namespace-comments)
507} // namespace HWY_NAMESPACE
508} // namespace hwy
510
511#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()
#define HWY_RESTRICT
Definition: base.h:61
#define HWY_POP_ATTRIBUTES
Definition: base.h:114
#define HWY_MIN(a, b)
Definition: base.h:125
#define HWY_ABORT(format,...)
Definition: base.h:141
#define HWY_INLINE
Definition: base.h:62
#define HWY_PUSH_ATTRIBUTES(targets_str)
Definition: base.h:113
Definition: algo-inl.h:242
static Vec< DU64 > RandomBits(DU64, Vec< DU64 > &state0, Vec< DU64 > &state1)
Definition: algo-inl.h:262
static void GenerateSeeds(DU64 du64, TFromD< DU64 > *HWY_RESTRICT seeds)
Definition: algo-inl.h:253
static HWY_INLINE uint64_t SplitMix64(uint64_t z)
Definition: algo-inl.h:243
Definition: algo-inl.h:117
T min_
Definition: algo-inl.h:155
size_t count_
Definition: algo-inl.h:158
T max_
Definition: algo-inl.h:156
bool operator==(const InputStats &other) const
Definition: algo-inl.h:132
void Notify(T value)
Definition: algo-inl.h:119
uint64_t sum_
Definition: algo-inl.h:157
Definition: vqsort.h:41
#define HWY_TARGET
Definition: detect_targets.h:341
void HeapSort(Traits st, T *HWY_RESTRICT lanes, const size_t num_lanes)
Definition: vqsort-inl.h:92
d
Definition: rvv-inl.h:1742
InputStats< T > GenerateInput(const Dist dist, T *v, size_t num)
Definition: algo-inl.h:323
void CallHeapSort(KeyType *HWY_RESTRICT keys, const size_t num_keys)
Definition: algo-inl.h:376
void Run(Algo algo, KeyType *HWY_RESTRICT inout, size_t num, SharedState &shared, size_t thread)
Definition: algo-inl.h:421
HWY_API Vec128< T, N > And(const Vec128< T, N > a, const Vec128< T, N > b)
Definition: arm_neon-inl.h:1934
Rebind< MakeUnsigned< TFromD< D > >, D > RebindToUnsigned
Definition: ops/shared-inl.h:200
Rebind< MakeFloat< TFromD< D > >, D > RebindToFloat
Definition: ops/shared-inl.h:202
HWY_API V Add(V a, V b)
Definition: arm_neon-inl.h:6274
HWY_API constexpr size_t Lanes(Simd< T, N, kPow2 >)
Definition: arm_sve-inl.h:236
HWY_API Vec128< T, N > Load(Simd< T, N, 0 > d, const T *HWY_RESTRICT p)
Definition: arm_neon-inl.h:2706
Vec< DU64 > MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t)
Definition: algo-inl.h:303
HWY_API Vec128< T, N > Xor(const Vec128< T, N > a, const Vec128< T, N > b)
Definition: arm_neon-inl.h:1983
HWY_API void StoreU(const Vec128< uint8_t > v, Full128< uint8_t >, uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2725
svuint16_t Set(Simd< bfloat16_t, N, kPow2 > d, bfloat16_t arg)
Definition: arm_sve-inl.h:312
HWY_API Vec128< T, N > OrAnd(Vec128< T, N > o, Vec128< T, N > a1, Vec128< T, N > a2)
Definition: arm_neon-inl.h:1999
HWY_API Vec128< T, N > BitCast(Simd< T, N, 0 > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition: arm_neon-inl.h:988
HWY_API Vec128< T, N > Zero(Simd< T, N, 0 > d)
Definition: arm_neon-inl.h:1011
typename D::template Repartition< T > Repartition
Definition: ops/shared-inl.h:206
HWY_API Vec128< float > ConvertTo(Full128< float >, const Vec128< int32_t > v)
Definition: arm_neon-inl.h:3273
N
Definition: rvv-inl.h:1742
ScalableTag< T, -1 > SortTag
Definition: contrib/sort/shared-inl.h:123
Vec< DU64 > RandomValues(DU64 du64, Vec< DU64 > &s0, Vec< DU64 > &s1, const Vec< DU64 > mask)
Definition: algo-inl.h:275
const vfloat64m1_t v
Definition: rvv-inl.h:1742
typename D::T TFromD
Definition: ops/shared-inl.h:191
decltype(Zero(D())) Vec
Definition: generic_ops-inl.h:32
Definition: aligned_allocator.h:27
static const char * DistName(Dist dist)
Definition: algo-inl.h:104
static HWY_MAYBE_UNUSED const char * TargetName(uint32_t target)
Definition: targets.h:77
Dist
Definition: algo-inl.h:98
static std::vector< Dist > AllDist()
Definition: algo-inl.h:100
const char * AlgoName(Algo algo)
Definition: algo-inl.h:185
Algo
Definition: algo-inl.h:161
#define HWY_NAMESPACE
Definition: set_macros-inl.h:82
Definition: algo-inl.h:364
std::vector< ThreadLocal > tls
Definition: algo-inl.h:370
Definition: algo-inl.h:360
Sorter sorter
Definition: algo-inl.h:361
Definition: sorting_networks-inl.h:686
Definition: traits-inl.h:381
Definition: base.h:264