vqsort-inl.h
1 // Copyright 2021 Google LLC
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 // Normal include guard for target-independent parts
17 #ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
18 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
19 
20 // Makes it harder for adversaries to predict our sampling locations, at the
21 // cost of 1-2% increased runtime.
22 #ifndef VQSORT_SECURE_RNG
23 #define VQSORT_SECURE_RNG 0
24 #endif
25 
26 #if VQSORT_SECURE_RNG
27 #include "third_party/absl/random/random.h"
28 #endif
29 
30 #include <string.h> // memcpy
31 
32 #include "hwy/cache_control.h" // Prefetch
34 #include "hwy/contrib/sort/vqsort.h" // Fill24Bytes
35 
36 #if HWY_IS_MSAN
37 #include <sanitizer/msan_interface.h>
38 #endif
39 
40 #endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
41 
42 // Per-target
43 #if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
44  defined(HWY_TARGET_TOGGLE)
45 #ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
46 #undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
47 #else
48 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
49 #endif
50 
51 #include "hwy/contrib/sort/shared-inl.h"
52 #include "hwy/contrib/sort/sorting_networks-inl.h"
53 #include "hwy/highway.h"
54 
55 HWY_BEFORE_NAMESPACE();
56 namespace hwy {
57 namespace HWY_NAMESPACE {
58 namespace detail {
59 
60 #if HWY_TARGET == HWY_SCALAR
61 
62 template <typename T>
63 void Swap(T* a, T* b) {
64  T t = *a;
65  *a = *b;
66  *b = t;
67 }
68 
69 // Scalar version of HeapSort (see below)
70 template <class Traits, typename T>
71 void HeapSort(Traits st, T* HWY_RESTRICT keys, const size_t num) {
72  if (num < 2) return;
73 
74  // Build heap.
75  for (size_t i = 1; i < num; i += 1) {
76  size_t j = i;
77  while (j != 0) {
78  const size_t idx_parent = ((j - 1) / 1 / 2);
79  if (!st.Compare1(keys + idx_parent, keys + j)) {
80  break;
81  }
82  Swap(keys + j, keys + idx_parent);
83  j = idx_parent;
84  }
85  }
86 
87  for (size_t i = num - 1; i != 0; i -= 1) {
88  // Swap root with last
89  Swap(keys + 0, keys + i);
90 
91  // Sift down the new root.
92  size_t j = 0;
93  while (j < i) {
94  const size_t left = 2 * j + 1;
95  const size_t right = 2 * j + 2;
96  if (left >= i) break;
97  size_t idx_larger = j;
98  if (st.Compare1(keys + j, keys + left)) {
99  idx_larger = left;
100  }
101  if (right < i && st.Compare1(keys + idx_larger, keys + right)) {
102  idx_larger = right;
103  }
104  if (idx_larger == j) break;
105  Swap(keys + j, keys + idx_larger);
106  j = idx_larger;
107  }
108  }
109 }
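// Illustrative sketch (added for exposition; ExampleHeap* are hypothetical
// helpers, not part of the original header): the loops above use the standard
// implicit binary-heap layout, in which node j's parent and children are
// computed from j rather than stored.
constexpr size_t ExampleHeapParent(size_t j) { return (j - 1) / 2; }
constexpr size_t ExampleHeapLeftChild(size_t j) { return 2 * j + 1; }
constexpr size_t ExampleHeapRightChild(size_t j) { return 2 * j + 2; }
// E.g. node 4's children are nodes 9 and 10, whose parent is again node 4.
static_assert(ExampleHeapLeftChild(4) == 9 && ExampleHeapRightChild(4) == 10,
              "children of node 4");
static_assert(ExampleHeapParent(9) == 4 && ExampleHeapParent(10) == 4,
              "parent of nodes 9 and 10");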
110 
111 #else
112 
114 
115 // ------------------------------ HeapSort
116 
117 // Heapsort: O(1) space, O(N*logN) worst-case comparisons.
118 // Based on LLVM sanitizer_common.h, licensed under Apache-2.0.
119 template <class Traits, typename T>
120 void HeapSort(Traits st, T* HWY_RESTRICT keys, const size_t num) {
121  constexpr size_t N1 = st.LanesPerKey();
122  const FixedTag<T, N1> d;
123 
124  if (num < 2 * N1) return;
125 
126  // Build heap.
127  for (size_t i = N1; i < num; i += N1) {
128  size_t j = i;
129  while (j != 0) {
130  const size_t idx_parent = ((j - N1) / N1 / 2) * N1;
131  if (AllFalse(d, st.Compare(d, st.SetKey(d, keys + idx_parent),
132  st.SetKey(d, keys + j)))) {
133  break;
134  }
135  st.Swap(keys + j, keys + idx_parent);
136  j = idx_parent;
137  }
138  }
139 
140  for (size_t i = num - N1; i != 0; i -= N1) {
141  // Swap root with last
142  st.Swap(keys + 0, keys + i);
143 
144  // Sift down the new root.
145  size_t j = 0;
146  while (j < i) {
147  const size_t left = 2 * j + N1;
148  const size_t right = 2 * j + 2 * N1;
149  if (left >= i) break;
150  size_t idx_larger = j;
151  const auto key_j = st.SetKey(d, keys + j);
152  if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, keys + left)))) {
153  idx_larger = left;
154  }
155  if (right < i && AllTrue(d, st.Compare(d, st.SetKey(d, keys + idx_larger),
156  st.SetKey(d, keys + right)))) {
157  idx_larger = right;
158  }
159  if (idx_larger == j) break;
160  st.Swap(keys + j, keys + idx_larger);
161  j = idx_larger;
162  }
163  }
164 }
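// Illustrative sketch (added; ExampleHeapParentLanes is a hypothetical
// helper): the vectorized HeapSort above operates on "key slots" of N1 lanes.
// Its parent computation is the scalar (j - 1) / 2 rule applied to slot
// indices and then rescaled back to lane offsets.
constexpr size_t ExampleHeapParentLanes(size_t j, size_t kN1) {
  return ((j - kN1) / kN1 / 2) * kN1;
}
// With one lane per key this matches the scalar rule; with two lanes per key
// (128-bit keys), the key at lane 12 (slot 6) has its parent at lane 4 (slot 2).
static_assert(ExampleHeapParentLanes(6, 1) == 2, "same as (6 - 1) / 2");
static_assert(ExampleHeapParentLanes(12, 2) == 4, "slot 6 -> slot 2 -> lane 4");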
165 
166 // ------------------------------ BaseCase
167 
168 // Sorts `keys` within the range [0, num) via sorting network.
169 template <class D, class Traits, typename T>
170 HWY_NOINLINE void BaseCase(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
171  T* HWY_RESTRICT buf) {
172  const size_t N = Lanes(d);
173  using V = decltype(Zero(d));
174 
175  // _Nonzero32 requires num - 1 != 0.
176  if (HWY_UNLIKELY(num <= 1)) return;
177 
178  // Reshape into a matrix with kMaxRows rows, and columns limited by the
179  // 1D `num`, which is upper-bounded by the vector width (see BaseCaseNum).
180  const size_t num_pow2 = size_t{1}
181  << (32 - Num0BitsAboveMS1Bit_Nonzero32(
182  static_cast<uint32_t>(num - 1)));
183  HWY_DASSERT(num <= num_pow2 && num_pow2 <= Constants::BaseCaseNum(N));
184  const size_t cols =
185  HWY_MAX(st.LanesPerKey(), num_pow2 >> Constants::kMaxRowsLog2);
186  HWY_DASSERT(cols <= N);
187 
188  // Copy `keys` to `buf`.
189  size_t i;
190  for (i = 0; i + N <= num; i += N) {
191  Store(LoadU(d, keys + i), d, buf + i);
192  }
193  SafeCopyN(num - i, d, keys + i, buf + i);
194  i = num;
195 
196  // Fill with padding - last in sort order, not copied to keys.
197  const V kPadding = st.LastValue(d);
198  // Initialize an extra vector because SortingNetwork loads full vectors,
199  // which may exceed cols*kMaxRows.
200  for (; i < (cols * Constants::kMaxRows + N); i += N) {
201  StoreU(kPadding, d, buf + i);
202  }
203 
204  SortingNetwork(st, buf, cols);
205 
206  for (i = 0; i + N <= num; i += N) {
207  StoreU(Load(d, buf + i), d, keys + i);
208  }
209  SafeCopyN(num - i, d, buf + i, keys + i);
210 }
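// Illustrative sketch (added; ExampleRoundUpToPow2 is a hypothetical helper):
// the num_pow2 computation above rounds num up to the next power of two via a
// leading-zero count. A portable scalar equivalent, for intuition only:
inline size_t ExampleRoundUpToPow2(uint32_t x) {  // requires x >= 1
  size_t p = 1;
  while (p < x) p <<= 1;
  return p;  // e.g. 100 -> 128, matching 1 << (32 - CLZ(100 - 1))
}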
211 
212 // ------------------------------ Partition
213 
214 // Consumes from `left` until a multiple of kUnroll*N remains.
215 // Temporarily stores the right side into `buf`, then moves behind `right`.
216 template <class D, class Traits, class T>
217 HWY_NOINLINE void PartitionToMultipleOfUnroll(D d, Traits st,
218  T* HWY_RESTRICT keys,
219  size_t& left, size_t& right,
220  const Vec<D> pivot,
221  T* HWY_RESTRICT buf) {
222  constexpr size_t kUnroll = Constants::kPartitionUnroll;
223  const size_t N = Lanes(d);
224  size_t readL = left;
225  size_t bufR = 0;
226  const size_t num = right - left;
227  // Partition requires both a multiple of kUnroll*N and at least
228  // 2*kUnroll*N for the initial loads. If less, consume all here.
229  const size_t num_rem =
230  (num < 2 * kUnroll * N) ? num : (num & (kUnroll * N - 1));
231  size_t i = 0;
232  for (; i + N <= num_rem; i += N) {
233  const Vec<D> vL = LoadU(d, keys + readL);
234  readL += N;
235 
236  const auto comp = st.Compare(d, pivot, vL);
237  left += CompressBlendedStore(vL, Not(comp), d, keys + left);
238  bufR += CompressStore(vL, comp, d, buf + bufR);
239  }
240  // Last iteration: only use valid lanes.
241  if (HWY_LIKELY(i != num_rem)) {
242  const auto mask = FirstN(d, num_rem - i);
243  const Vec<D> vL = LoadU(d, keys + readL);
244 
245  const auto comp = st.Compare(d, pivot, vL);
246  left += CompressBlendedStore(vL, AndNot(comp, mask), d, keys + left);
247  bufR += CompressStore(vL, And(comp, mask), d, buf + bufR);
248  }
249 
250  // MSAN seems not to understand CompressStore. buf[0, bufR) are valid.
251 #if HWY_IS_MSAN
252  __msan_unpoison(buf, bufR * sizeof(T));
253 #endif
254 
255  // Everything we loaded was put into buf, or behind the new `left`, after
256  // which there is space for bufR items. First move items from `right` to
257  // `left` to free up space, then copy `buf` into the vacated `right`.
258  // A loop with masked loads from `buf` is insufficient - we would also need to
259  // mask from `right`. Combining a loop with memcpy for the remainders is
260  // slower than just memcpy, so we use that for simplicity.
261  right -= bufR;
262  memcpy(keys + left, keys + right, bufR * sizeof(T));
263  memcpy(keys + right, buf, bufR * sizeof(T));
264 }
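// Added numeric illustration (hypothetical values, assuming kUnroll * N == 32):
// a range of 100 keys leaves 100 & 31 = 4 keys to be consumed here; the other
// 96 keys form whole blocks for the unrolled loop in Partition below.
static_assert((100 & (32 - 1)) == 4, "remainder consumed by this function");
static_assert((100 - 4) % 32 == 0, "rest is a whole number of unrolled blocks");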
265 
266 template <class D, class Traits, typename T>
267 HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec<D> v,
268  const Vec<D> pivot, T* HWY_RESTRICT keys,
269  size_t& writeL, size_t& writeR) {
270  const size_t N = Lanes(d);
271 
272  const auto comp = st.Compare(d, pivot, v);
273 
274  if (hwy::HWY_NAMESPACE::CompressIsPartition<T>::value) {
275  // Non-native Compress (e.g. AVX2): we are able to partition a vector using
276  // a single Compress+two StoreU instead of two Compress[Blended]Store. The
277  // latter are more expensive. Because we store entire vectors, the contents
278  // between the updated writeL and writeR are ignored and will be overwritten
279  // by subsequent calls. This works because writeL and writeR are at least
280  // two vectors apart.
281  const auto mask = Not(comp);
282  const auto lr = Compress(v, mask);
283  const size_t num_left = CountTrue(d, mask);
284  StoreU(lr, d, keys + writeL);
285  writeL += num_left;
286  // Now write the right-side elements (if any), such that the previous writeR
287  // is one past the end of the newly written right elements, then advance.
288  StoreU(lr, d, keys + writeR - N);
289  writeR -= (N - num_left);
290  } else {
291  // Native Compress[Store] (e.g. AVX3), which only keep the left or right
292  // side, not both, hence we require two calls.
293  const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL);
294  writeL += num_left;
295 
296  writeR -= (N - num_left);
297  (void)CompressBlendedStore(v, comp, d, keys + writeR);
298  }
299 }
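// Illustrative sketch (added; ExamplePartitionCompress is a hypothetical
// helper): a scalar emulation of the "Compress is a partition" property that
// the first branch above relies on - lanes whose mask bit is set come first,
// the remaining lanes follow, each group keeping its original order.
template <typename T, size_t kN>
inline void ExamplePartitionCompress(const T (&v)[kN], const bool (&keep)[kN],
                                     T (&out)[kN]) {
  size_t w = 0;
  for (size_t i = 0; i < kN; ++i) {
    if (keep[i]) out[w++] = v[i];  // "<= pivot" lanes, in original order
  }
  for (size_t i = 0; i < kN; ++i) {
    if (!keep[i]) out[w++] = v[i];  // "> pivot" lanes follow
  }
  // Storing such a vector once at writeL and once at writeR - kN publishes the
  // left group at the left cursor and the right group just before writeR; the
  // overlap in between is scratch that later calls overwrite.
}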
300 
301 template <class D, class Traits, typename T>
302 HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec<D> v0,
303  const Vec<D> v1, const Vec<D> v2,
304  const Vec<D> v3, const Vec<D> pivot,
305  T* HWY_RESTRICT keys, size_t& writeL,
306  size_t& writeR) {
307  StoreLeftRight(d, st, v0, pivot, keys, writeL, writeR);
308  StoreLeftRight(d, st, v1, pivot, keys, writeL, writeR);
309  StoreLeftRight(d, st, v2, pivot, keys, writeL, writeR);
310  StoreLeftRight(d, st, v3, pivot, keys, writeL, writeR);
311 }
312 
313 // Moves "<= pivot" keys to the front, and others to the back. pivot is
314 // broadcasted. Time-critical!
315 //
316 // Aligned loads do not seem to be worthwhile (not bottlenecked by load ports).
317 template <class D, class Traits, typename T>
318 HWY_NOINLINE size_t Partition(D d, Traits st, T* HWY_RESTRICT keys, size_t left,
319  size_t right, const Vec<D> pivot,
320  T* HWY_RESTRICT buf) {
321  using V = decltype(Zero(d));
322  const size_t N = Lanes(d);
323 
324  // StoreLeftRight will CompressBlendedStore ending at `writeR`. Unless all
325  // lanes happen to be in the right-side partition, this will overrun `keys`,
326  // which triggers asan errors. Avoid by special-casing the last vector.
327  HWY_DASSERT(right - left > 2 * N); // ensured by HandleSpecialCases
328  right -= N;
329  const size_t last = right;
330  const V vlast = LoadU(d, keys + last);
331 
332  PartitionToMultipleOfUnroll(d, st, keys, left, right, pivot, buf);
333  constexpr size_t kUnroll = Constants::kPartitionUnroll;
334 
335  // Invariant: [left, writeL) and [writeR, right) are already partitioned.
336  size_t writeL = left;
337  size_t writeR = right;
338 
339  const size_t num = right - left;
340  // Cannot load if there were fewer than 2 * kUnroll * N.
341  if (HWY_LIKELY(num != 0)) {
342  HWY_DASSERT(num >= 2 * kUnroll * N);
343  HWY_DASSERT((num & (kUnroll * N - 1)) == 0);
344 
345  // Make space for writing in-place by reading from left and right.
346  const V vL0 = LoadU(d, keys + left + 0 * N);
347  const V vL1 = LoadU(d, keys + left + 1 * N);
348  const V vL2 = LoadU(d, keys + left + 2 * N);
349  const V vL3 = LoadU(d, keys + left + 3 * N);
350  left += kUnroll * N;
351  right -= kUnroll * N;
352  const V vR0 = LoadU(d, keys + right + 0 * N);
353  const V vR1 = LoadU(d, keys + right + 1 * N);
354  const V vR2 = LoadU(d, keys + right + 2 * N);
355  const V vR3 = LoadU(d, keys + right + 3 * N);
356 
357  // The left/right updates may consume all inputs, so check before the loop.
358  while (left != right) {
359  V v0, v1, v2, v3;
360 
361  // Free up capacity for writing by loading from the side that has less.
362  // Data-dependent but branching is faster than forcing branch-free.
363  const size_t capacityL = left - writeL;
364  const size_t capacityR = writeR - right;
365  HWY_DASSERT(capacityL <= num && capacityR <= num); // >= 0
366  if (capacityR < capacityL) {
367  right -= kUnroll * N;
368  v0 = LoadU(d, keys + right + 0 * N);
369  v1 = LoadU(d, keys + right + 1 * N);
370  v2 = LoadU(d, keys + right + 2 * N);
371  v3 = LoadU(d, keys + right + 3 * N);
372  hwy::Prefetch(keys + right - 3 * kUnroll * N);
373  } else {
374  v0 = LoadU(d, keys + left + 0 * N);
375  v1 = LoadU(d, keys + left + 1 * N);
376  v2 = LoadU(d, keys + left + 2 * N);
377  v3 = LoadU(d, keys + left + 3 * N);
378  left += kUnroll * N;
379  hwy::Prefetch(keys + left + 3 * kUnroll * N);
380  }
381 
382  StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, writeR);
383  }
384 
385  // Now finish writing the initial left/right to the middle.
386  StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, writeR);
387  StoreLeftRight4(d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, writeR);
388  }
389 
390  // We have partitioned [left, right) such that writeL is the boundary.
391  HWY_DASSERT(writeL == writeR);
392  // Make space for inserting vlast: move up to N of the first right-side keys
393  // into the unused space starting at last. If we have fewer, ensure they are
394  // the last items in that vector by subtracting from the *load* address,
395  // which is safe because we have at least two vectors (checked above).
396  const size_t totalR = last - writeL;
397  const size_t startR = totalR < N ? writeL + totalR - N : writeL;
398  StoreU(LoadU(d, keys + startR), d, keys + last);
399 
400  // Partition vlast: write L, then R, into the single-vector gap at writeL.
401  const auto comp = st.Compare(d, pivot, vlast);
402  writeL += CompressBlendedStore(vlast, Not(comp), d, keys + writeL);
403  (void)CompressBlendedStore(vlast, comp, d, keys + writeL);
404 
405  return writeL;
406 }
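// Illustrative scalar reference (added; ExampleScalarPartition is a
// hypothetical helper): the contract of Partition for ascending order - the
// returned bound splits the range into keys that do not come after the pivot,
// followed by keys that do. The SIMD code above produces the same split, but
// not stably and without per-key branches.
template <typename T>
inline size_t ExampleScalarPartition(T* keys, size_t left, size_t right,
                                     const T pivot) {
  size_t bound = left;
  for (size_t i = left; i < right; ++i) {
    if (!(pivot < keys[i])) {  // keys[i] <= pivot: belongs in the left part
      const T tmp = keys[bound];
      keys[bound] = keys[i];
      keys[i] = tmp;
      ++bound;
    }
  }
  return bound;  // keys[left, bound) <= pivot < keys[bound, right)
}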
407 
408 // ------------------------------ Pivot
409 
410 template <class Traits, class V>
411 HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
412  const DFromV<V> d;
413  // Slightly faster for 128-bit, apparently because not serially dependent.
414  if (st.Is128()) {
415  // Median = XOR-sum 'minus' the first and last. Calling First twice is
416  // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR.
417  const auto sum = Xor(Xor(v0, v1), v2);
418  const auto first = st.First(d, st.First(d, v0, v1), v2);
419  const auto last = st.Last(d, st.Last(d, v0, v1), v2);
420  return Xor(Xor(sum, first), last);
421  }
422  st.Sort2(d, v0, v2);
423  v1 = st.Last(d, v0, v1);
424  v1 = st.First(d, v1, v2);
425  return v1;
426 }
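// Illustrative sketch (added; ExampleMedianOf3Scalar is a hypothetical
// helper): the 128-bit path above uses the identity
// median = (a ^ b ^ c) ^ min ^ max, which holds because {min, median, max} is
// a permutation of {a, b, c}. A scalar version of the same idea:
inline uint64_t ExampleMedianOf3Scalar(uint64_t a, uint64_t b, uint64_t c) {
  const uint64_t lo = a < b ? (a < c ? a : c) : (b < c ? b : c);  // min
  const uint64_t hi = a > b ? (a > c ? a : c) : (b > c ? b : c);  // max
  return (a ^ b ^ c) ^ lo ^ hi;  // XOR cancels min and max, leaving the median
}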
427 
428 // Replaces triplets with their median and recurses until less than 3 keys
429 // remain. Ignores leftover values (non-whole triplets)!
430 template <class D, class Traits, typename T>
431 Vec<D> RecursiveMedianOf3(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
432  T* HWY_RESTRICT buf) {
433  const size_t N = Lanes(d);
434  constexpr size_t N1 = st.LanesPerKey();
435 
436  if (num < 3 * N1) return st.SetKey(d, keys);
437 
438  size_t read = 0;
439  size_t written = 0;
440 
441  // Triplets of vectors
442  for (; read + 3 * N <= num; read += 3 * N) {
443  const auto v0 = Load(d, keys + read + 0 * N);
444  const auto v1 = Load(d, keys + read + 1 * N);
445  const auto v2 = Load(d, keys + read + 2 * N);
446  Store(MedianOf3(st, v0, v1, v2), d, buf + written);
447  written += N;
448  }
449 
450  // Triplets of keys
451  for (; read + 3 * N1 <= num; read += 3 * N1) {
452  const auto v0 = st.SetKey(d, keys + read + 0 * N1);
453  const auto v1 = st.SetKey(d, keys + read + 1 * N1);
454  const auto v2 = st.SetKey(d, keys + read + 2 * N1);
455  StoreU(MedianOf3(st, v0, v1, v2), d, buf + written);
456  written += N1;
457  }
458 
459  // Tail recursion; swap buffers
460  return RecursiveMedianOf3(d, st, buf, written, keys);
461 }
462 
463 #if VQSORT_SECURE_RNG
464 using Generator = absl::BitGen;
465 #else
466 // Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028
467 #pragma pack(push, 1)
468 class Generator {
469  public:
470  Generator(const void* heap, size_t num) {
471  Sorter::Fill24Bytes(heap, num, &a_);
472  k_ = 1; // stream index: must be odd
473  }
474 
475  uint64_t operator()() {
476  const uint64_t b = b_;
477  w_ += k_;
478  const uint64_t next = a_ ^ w_;
479  a_ = (b + (b << 3)) ^ (b >> 11);
480  const uint64_t rot = (b << 24) | (b >> 40);
481  b_ = rot + next;
482  return next;
483  }
484 
485  private:
486  uint64_t a_;
487  uint64_t b_;
488  uint64_t w_;
489  uint64_t k_; // increment
490 };
491 #pragma pack(pop)
492 
493 #endif // !VQSORT_SECURE_RNG
494 
495 // Returns slightly biased random index of a chunk in [0, num_chunks).
496 // See https://www.pcg-random.org/posts/bounded-rands.html.
497 HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) {
498  const uint64_t chunk_index = (static_cast<uint64_t>(bits) * num_chunks) >> 32;
499  HWY_DASSERT(chunk_index < num_chunks);
500  return static_cast<size_t>(chunk_index);
501 }
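// Added numeric illustration: the multiply-shift above maps 32 random bits
// into [0, num_chunks) without a division (the bounded-rand technique from the
// linked post). For example, bits = 0x80000000 (roughly half of 2^32) with 10
// chunks selects chunk 5:
static_assert(((static_cast<uint64_t>(0x80000000u) * 10) >> 32) == 5,
              "half of the 32-bit range maps to the middle chunk");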
502 
503 template <class D, class Traits, typename T>
504 HWY_NOINLINE Vec<D> ChoosePivot(D d, Traits st, T* HWY_RESTRICT keys,
505  const size_t begin, const size_t end,
506  T* HWY_RESTRICT buf, Generator& rng) {
507  using V = decltype(Zero(d));
508  const size_t N = Lanes(d);
509 
510  // Power of two
511  const size_t lanes_per_chunk = Constants::LanesPerChunk(sizeof(T), N);
512 
513  keys += begin;
514  size_t num = end - begin;
515 
516  // Align start of keys to chunks. We always have at least 2 chunks because the
517  // base case would have handled anything up to 16 vectors, i.e. >= 4 chunks.
518  HWY_DASSERT(num >= 2 * lanes_per_chunk);
519  const size_t misalign =
520  (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (lanes_per_chunk - 1);
521  if (misalign != 0) {
522  const size_t consume = lanes_per_chunk - misalign;
523  keys += consume;
524  num -= consume;
525  }
526 
527  // Generate enough random bits for 9 uint32
528  uint64_t* bits64 = reinterpret_cast<uint64_t*>(buf);
529  for (size_t i = 0; i < 5; ++i) {
530  bits64[i] = rng();
531  }
532  const uint32_t* bits = reinterpret_cast<const uint32_t*>(buf);
533 
534  const uint32_t lpc32 = static_cast<uint32_t>(lanes_per_chunk);
535  // Avoid division
536  const size_t log2_lpc = Num0BitsBelowLS1Bit_Nonzero32(lpc32);
537  const size_t num_chunks64 = num >> log2_lpc;
538  // Clamp to uint32 for RandomChunkIndex
539  const uint32_t num_chunks =
540  static_cast<uint32_t>(HWY_MIN(num_chunks64, 0xFFFFFFFFull));
541 
542  const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) << log2_lpc;
543  const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) << log2_lpc;
544  const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) << log2_lpc;
545  const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) << log2_lpc;
546  const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) << log2_lpc;
547  const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) << log2_lpc;
548  const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) << log2_lpc;
549  const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) << log2_lpc;
550  const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) << log2_lpc;
551  for (size_t i = 0; i < lanes_per_chunk; i += N) {
552  const V v0 = Load(d, keys + offset0 + i);
553  const V v1 = Load(d, keys + offset1 + i);
554  const V v2 = Load(d, keys + offset2 + i);
555  const V medians0 = MedianOf3(st, v0, v1, v2);
556  Store(medians0, d, buf + i);
557 
558  const V v3 = Load(d, keys + offset3 + i);
559  const V v4 = Load(d, keys + offset4 + i);
560  const V v5 = Load(d, keys + offset5 + i);
561  const V medians1 = MedianOf3(st, v3, v4, v5);
562  Store(medians1, d, buf + i + lanes_per_chunk);
563 
564  const V v6 = Load(d, keys + offset6 + i);
565  const V v7 = Load(d, keys + offset7 + i);
566  const V v8 = Load(d, keys + offset8 + i);
567  const V medians2 = MedianOf3(st, v6, v7, v8);
568  Store(medians2, d, buf + i + lanes_per_chunk * 2);
569  }
570 
571  return RecursiveMedianOf3(d, st, buf, 3 * lanes_per_chunk,
572  buf + 3 * lanes_per_chunk);
573 }
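// Illustrative scalar analogue (added; ExamplePseudoMedianOf9 is a
// hypothetical helper reusing the ExampleMedianOf3Scalar sketch above):
// ChoosePivot is, loosely, a vectorized pseudo-median - per-lane medians of
// three sampled triples, followed by a median of those medians via
// RecursiveMedianOf3.
inline uint64_t ExamplePseudoMedianOf9(const uint64_t s[9]) {
  const uint64_t m0 = ExampleMedianOf3Scalar(s[0], s[1], s[2]);
  const uint64_t m1 = ExampleMedianOf3Scalar(s[3], s[4], s[5]);
  const uint64_t m2 = ExampleMedianOf3Scalar(s[6], s[7], s[8]);
  return ExampleMedianOf3Scalar(m0, m1, m2);
}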
574 
575 // Compute exact min/max to detect all-equal partitions. Only called after a
576 // degenerate Partition (none in the right partition).
577 template <class D, class Traits, typename T>
578 HWY_NOINLINE void ScanMinMax(D d, Traits st, const T* HWY_RESTRICT keys,
579  size_t num, T* HWY_RESTRICT buf, Vec<D>& first,
580  Vec<D>& last) {
581  const size_t N = Lanes(d);
582 
583  first = st.LastValue(d);
584  last = st.FirstValue(d);
585 
586  size_t i = 0;
587  for (; i + N <= num; i += N) {
588  const Vec<D> v = LoadU(d, keys + i);
589  first = st.First(d, v, first);
590  last = st.Last(d, v, last);
591  }
592  if (HWY_LIKELY(i != num)) {
593  HWY_DASSERT(num >= N); // See HandleSpecialCases
594  const Vec<D> v = LoadU(d, keys + num - N);
595  first = st.First(d, v, first);
596  last = st.Last(d, v, last);
597  }
598 
599  first = st.FirstOfLanes(d, first, buf);
600  last = st.LastOfLanes(d, last, buf);
601 }
602 
603 template <class D, class Traits, typename T>
604 void Recurse(D d, Traits st, T* HWY_RESTRICT keys, const size_t begin,
605  const size_t end, const Vec<D> pivot, T* HWY_RESTRICT buf,
606  Generator& rng, size_t remaining_levels) {
607  HWY_DASSERT(begin + 1 < end);
608  const size_t num = end - begin; // >= 2
609 
610  // Too many degenerate partitions. This is extremely unlikely to happen
611  // because we select pivots from large (though still O(1)) samples.
612  if (HWY_UNLIKELY(remaining_levels == 0)) {
613  HeapSort(st, keys + begin, num); // Slow but N*logN.
614  return;
615  }
616 
617  const ptrdiff_t base_case_num =
618  static_cast<ptrdiff_t>(Constants::BaseCaseNum(Lanes(d)));
619  const size_t bound = Partition(d, st, keys, begin, end, pivot, buf);
620 
621  const ptrdiff_t num_left =
622  static_cast<ptrdiff_t>(bound) - static_cast<ptrdiff_t>(begin);
623  const ptrdiff_t num_right =
624  static_cast<ptrdiff_t>(end) - static_cast<ptrdiff_t>(bound);
625 
626  // Check for degenerate partitions (i.e. Partition did not move any keys):
627  if (HWY_UNLIKELY(num_right == 0)) {
628  // Because the pivot is one of the keys, it must have been equal to the
629  // first or last key in sort order. Scan for the actual min/max:
630  // passing the current pivot as the new bound is insufficient because one of
631  // the partitions might not actually include that key.
632  Vec<D> first, last;
633  ScanMinMax(d, st, keys + begin, num, buf, first, last);
634  if (AllTrue(d, Eq(first, last))) return;
635 
636  // Separate recursion to make sure that we don't pick `last` as the
637  // pivot - that would again lead to a degenerate partition.
638  Recurse(d, st, keys, begin, end, first, buf, rng, remaining_levels - 1);
639  return;
640  }
641 
642  if (HWY_UNLIKELY(num_left <= base_case_num)) {
643  BaseCase(d, st, keys + begin, static_cast<size_t>(num_left), buf);
644  } else {
645  const Vec<D> next_pivot = ChoosePivot(d, st, keys, begin, bound, buf, rng);
646  Recurse(d, st, keys, begin, bound, next_pivot, buf, rng,
647  remaining_levels - 1);
648  }
649  if (HWY_UNLIKELY(num_right <= base_case_num)) {
650  BaseCase(d, st, keys + bound, static_cast<size_t>(num_right), buf);
651  } else {
652  const Vec<D> next_pivot = ChoosePivot(d, st, keys, bound, end, buf, rng);
653  Recurse(d, st, keys, bound, end, next_pivot, buf, rng,
654  remaining_levels - 1);
655  }
656 }
657 
658 // Returns true if sorting is finished.
659 template <class D, class Traits, typename T>
660 bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
661  T* HWY_RESTRICT buf) {
662  const size_t N = Lanes(d);
663  const size_t base_case_num = Constants::BaseCaseNum(N);
664 
665  // 128-bit keys require vectors with at least two u64 lanes, which is always
666  // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the
667  // hardware vector width is less than 128bit / fraction.
668  const bool partial_128 = N < 2 && st.Is128();
669  // Partition assumes its input is at least two vectors. If vectors are huge,
670  // base_case_num may actually be smaller. If so (which is only possible on
671  // RVV), pass a capped or partial d (LMUL < 1).
672  constexpr bool kPotentiallyHuge =
673  HWY_MAX_BYTES / sizeof(T) > Constants::kMaxRows * Constants::kMaxCols;
674  const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num);
675  if (partial_128 || huge_vec) {
676  // PERFORMANCE WARNING: falling back to HeapSort.
677  HeapSort(st, keys, num);
678  return true;
679  }
680 
681  // Small arrays: use sorting network, no need for other checks.
682  if (HWY_UNLIKELY(num <= base_case_num)) {
683  BaseCase(d, st, keys, num, buf);
684  return true;
685  }
686 
687  // We could also check for already sorted/reverse/equal, but that's probably
688  // counterproductive if vqsort is used as a base case.
689 
690  return false; // not finished sorting
691 }
692 
693 #endif // HWY_TARGET != HWY_SCALAR
694 } // namespace detail
695 
696 // Sorts `keys[0..num-1]` according to the order defined by `st.Compare`.
697 // In-place i.e. O(1) additional storage. Worst-case N*logN comparisons.
698 // Non-stable (order of equal keys may change), except for the common case where
699 // the upper bits of T are the key, and the lower bits are a sequential or at
700 // least unique ID.
701 // There is no upper limit on `num`, but note that pivots may be chosen by
702 // sampling only from the first 256 GiB.
703 //
704 // `d` is typically SortTag<T> (chooses between full and partial vectors).
705 // `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
706 // differences in sort order and single-lane vs 128-bit keys.
707 template <class D, class Traits, typename T>
708 void Sort(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
709  T* HWY_RESTRICT buf) {
710 #if HWY_TARGET == HWY_SCALAR
711  (void)d;
712  (void)buf;
713  // PERFORMANCE WARNING: vqsort is not enabled for the non-SIMD target
714  return detail::HeapSort(st, keys, num);
715 #else
716 #if !HWY_HAVE_SCALABLE
717  // On targets with fixed-size vectors, avoid _using_ the allocated memory.
718  // We avoid (potentially expensive for small input sizes) allocations on
719  // platforms where no targets are scalable. For 512-bit vectors, this fits on
720  // the stack (several KiB).
721  HWY_ALIGN T storage[SortConstants::BufNum<T>(HWY_LANES(T))] = {};
722  static_assert(sizeof(storage) <= 8192, "Unexpectedly large, check size");
723  buf = storage;
724 #endif // !HWY_HAVE_SCALABLE
725 
726  if (detail::HandleSpecialCases(d, st, keys, num, buf)) return;
727 
728 #if HWY_MAX_BYTES > 64
729  // sorting_networks-inl and traits assume no more than 512 bit vectors.
730  if (Lanes(d) > 64 / sizeof(T)) {
731  return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf);
732  }
733 #endif // HWY_MAX_BYTES > 64
734 
735  // Pulled out of the recursion so we can special-case degenerate partitions.
736  detail::Generator rng(keys, num);
737  const Vec<D> pivot = detail::ChoosePivot(d, st, keys, 0, num, buf, rng);
738 
739  // Introspection: switch to worst-case N*logN heapsort after this many.
740  const size_t max_levels = 2 * hwy::CeilLog2(num) + 4;
741 
742  detail::Recurse(d, st, keys, 0, num, pivot, buf, rng, max_levels);
743 #endif // HWY_TARGET == HWY_SCALAR
744 }
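// Added usage sketch (illustrative): any (d, st) pair accepted by the traits
// layer can be forwarded to Sort; ExampleSortCall is a hypothetical wrapper.
// The concrete names in the comment below (SortTag, SharedTraits, TraitsLane,
// OrderAscending) are assumptions about the companion vqsort.h / traits
// headers and may differ between Highway versions.
template <class D, class Traits, typename T>
void ExampleSortCall(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
                     T* HWY_RESTRICT buf) {
  Sort(d, st, keys, num, buf);  // in-place, O(1) additional storage
}
// e.g. (hedged, see note above):
//   const SortTag<float> d;
//   SharedTraits<TraitsLane<OrderAscending>> st;
//   Sort(d, st, keys, num, buf);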
745 
746 // NOLINTNEXTLINE(google-readability-namespace-comments)
747 } // namespace HWY_NAMESPACE
748 } // namespace hwy
749 HWY_AFTER_NAMESPACE();
750 
751 #endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE