Grok 10.0.0
traits-inl.h
Go to the documentation of this file.
1// Copyright 2021 Google LLC
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Per-target
17#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \
18 defined(HWY_TARGET_TOGGLE)
19#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
20#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
21#else
22#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
23#endif
24
25#include <string>
26
27#include "hwy/contrib/sort/shared-inl.h" // SortConstants
28#include "hwy/contrib/sort/vqsort.h" // SortDescending
29#include "hwy/highway.h"
30#include "hwy/print.h"
31
33namespace hwy {
34namespace HWY_NAMESPACE {
35namespace detail {
36
37#if VQSORT_ENABLED || HWY_IDE
38
39// Highway does not provide a lane type for 128-bit keys, so we use uint64_t
40// along with an abstraction layer for single-lane vs. lane-pair, which is
41// independent of the order.
42template <typename T>
43struct KeyLane {
44 constexpr bool Is128() const { return false; }
45 constexpr size_t LanesPerKey() const { return 1; }
46
47 // What type bench_sort should allocate for generating inputs.
48 using LaneType = T;
49 // What type to pass to Sorter::operator().
50 using KeyType = T;
51
52 std::string KeyString() const {
53 char string100[100];
54 hwy::detail::TypeName(hwy::detail::MakeTypeInfo<KeyType>(), 1, string100);
55 return string100;
56 }
57
58 // For HeapSort
59 HWY_INLINE void Swap(T* a, T* b) const {
60 const T temp = *a;
61 *a = *b;
62 *b = temp;
63 }
64
65 template <class V, class M>
66 HWY_INLINE V CompressKeys(V keys, M mask) const {
67 return CompressNot(keys, mask);
68 }
69
70 // Broadcasts one key into a vector
71 template <class D>
72 HWY_INLINE Vec<D> SetKey(D d, const T* key) const {
73 return Set(d, *key);
74 }
75
76 template <class D>
77 HWY_INLINE Vec<D> ReverseKeys(D d, Vec<D> v) const {
78 return Reverse(d, v);
79 }
80
81 template <class D>
82 HWY_INLINE Vec<D> ReverseKeys2(D d, Vec<D> v) const {
83 return Reverse2(d, v);
84 }
85
86 template <class D>
87 HWY_INLINE Vec<D> ReverseKeys4(D d, Vec<D> v) const {
88 return Reverse4(d, v);
89 }
90
91 template <class D>
92 HWY_INLINE Vec<D> ReverseKeys8(D d, Vec<D> v) const {
93 return Reverse8(d, v);
94 }
95
96 template <class D>
97 HWY_INLINE Vec<D> ReverseKeys16(D d, Vec<D> v) const {
98 static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit");
99 return ReverseKeys(d, v);
100 }
101
102 template <class V>
103 HWY_INLINE V OddEvenKeys(const V odd, const V even) const {
104 return OddEven(odd, even);
105 }
106
107 template <class D, HWY_IF_LANE_SIZE_D(D, 2)>
108 HWY_INLINE Vec<D> SwapAdjacentPairs(D d, const Vec<D> v) const {
109 const Repartition<uint32_t, D> du32;
110 return BitCast(d, Shuffle2301(BitCast(du32, v)));
111 }
112 template <class D, HWY_IF_LANE_SIZE_D(D, 4)>
113 HWY_INLINE Vec<D> SwapAdjacentPairs(D /* tag */, const Vec<D> v) const {
114 return Shuffle1032(v);
115 }
116 template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
117 HWY_INLINE Vec<D> SwapAdjacentPairs(D /* tag */, const Vec<D> v) const {
118 return SwapAdjacentBlocks(v);
119 }
120
121 template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)>
122 HWY_INLINE Vec<D> SwapAdjacentQuads(D d, const Vec<D> v) const {
123#if HWY_HAVE_FLOAT64 // in case D is float32
124 const RepartitionToWide<D> dw;
125#else
126 const RepartitionToWide<RebindToUnsigned<D>> dw;
127#endif
128 return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v)));
129 }
130 template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
131 HWY_INLINE Vec<D> SwapAdjacentQuads(D d, const Vec<D> v) const {
132 // Assumes max vector size = 512
133 return ConcatLowerUpper(d, v, v);
134 }
135
136 template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)>
137 HWY_INLINE Vec<D> OddEvenPairs(D d, const Vec<D> odd,
138 const Vec<D> even) const {
139#if HWY_HAVE_FLOAT64 // in case D is float32
140 const RepartitionToWide<D> dw;
141#else
142 const RepartitionToWide<RebindToUnsigned<D>> dw;
143#endif
144 return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even)));
145 }
146 template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
147 HWY_INLINE Vec<D> OddEvenPairs(D /* tag */, Vec<D> odd, Vec<D> even) const {
148 return OddEvenBlocks(odd, even);
149 }
150
151 template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)>
152 HWY_INLINE Vec<D> OddEvenQuads(D d, Vec<D> odd, Vec<D> even) const {
153#if HWY_HAVE_FLOAT64 // in case D is float32
154 const RepartitionToWide<D> dw;
155#else
156 const RepartitionToWide<RebindToUnsigned<D>> dw;
157#endif
158 return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even)));
159 }
160 template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
161 HWY_INLINE Vec<D> OddEvenQuads(D d, Vec<D> odd, Vec<D> even) const {
162 return ConcatUpperLower(d, odd, even);
163 }
164};
165
166// Anything order-related depends on the key traits *and* the order (see
167// FirstOfLanes). We cannot implement just one Compare function because Lt128
168// only compiles if the lane type is u64. Thus we need either overloaded
169// functions with a tag type, class specializations, or separate classes.
170// We avoid overloaded functions because we want all functions to be callable
171// from a SortTraits without per-function wrappers. Specializing would work, but
172// we are anyway going to specialize at a higher level.
173template <typename T>
174struct OrderAscending : public KeyLane<T> {
175 using Order = SortAscending;
176
177 HWY_INLINE bool Compare1(const T* a, const T* b) {
178 return *a < *b;
179 }
180
181 template <class D>
182 HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const {
183 return Lt(a, b);
184 }
185
186 // Two halves of Sort2, used in ScanMinMax.
187 template <class D>
188 HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
189 return Min(a, b);
190 }
191
192 template <class D>
193 HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
194 return Max(a, b);
195 }
196
197 template <class D>
198 HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
199 T* HWY_RESTRICT /* buf */) const {
200 return MinOfLanes(d, v);
201 }
202
203 template <class D>
204 HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
205 T* HWY_RESTRICT /* buf */) const {
206 return MaxOfLanes(d, v);
207 }
208
209 template <class D>
210 HWY_INLINE Vec<D> FirstValue(D d) const {
211 return Set(d, hwy::LowestValue<T>());
212 }
213
214 template <class D>
215 HWY_INLINE Vec<D> LastValue(D d) const {
216 return Set(d, hwy::HighestValue<T>());
217 }
218};
219
220template <typename T>
221struct OrderDescending : public KeyLane<T> {
222 using Order = SortDescending;
223
224 HWY_INLINE bool Compare1(const T* a, const T* b) {
225 return *b < *a;
226 }
227
228 template <class D>
229 HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const {
230 return Lt(b, a);
231 }
232
233 template <class D>
234 HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
235 return Max(a, b);
236 }
237
238 template <class D>
239 HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
240 return Min(a, b);
241 }
242
243 template <class D>
244 HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
245 T* HWY_RESTRICT /* buf */) const {
246 return MaxOfLanes(d, v);
247 }
248
249 template <class D>
250 HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
251 T* HWY_RESTRICT /* buf */) const {
252 return MinOfLanes(d, v);
253 }
254
255 template <class D>
256 HWY_INLINE Vec<D> FirstValue(D d) const {
257 return Set(d, hwy::HighestValue<T>());
258 }
259
260 template <class D>
261 HWY_INLINE Vec<D> LastValue(D d) const {
262 return Set(d, hwy::LowestValue<T>());
263 }
264};
265
266// Shared code that depends on Order.
267template <class Base>
268struct TraitsLane : public Base {
269 // For each lane i: replaces a[i] with the first and b[i] with the second
270 // according to Base.
271 // Corresponds to a conditional swap, which is one "node" of a sorting
272 // network. Min/Max are cheaper than compare + blend at least for integers.
273 template <class D>
274 HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const {
275 const Base* base = static_cast<const Base*>(this);
276
277 const Vec<D> a_copy = a;
278 // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4
279 // instructions. We can reduce it to a compare + 2 IfThenElse.
280#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3
281 if (sizeof(TFromD<D>) == 8) {
282 const Mask<D> cmp = base->Compare(d, a, b);
283 a = IfThenElse(cmp, a, b);
284 b = IfThenElse(cmp, b, a_copy);
285 return;
286 }
287#endif
288 a = base->First(d, a, b);
289 b = base->Last(d, a_copy, b);
290 }
291
292 // Conditionally swaps even-numbered lanes with their odd-numbered neighbor.
293 template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
294 HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
295 const Base* base = static_cast<const Base*>(this);
296 Vec<D> swapped = base->ReverseKeys2(d, v);
297 // Further to the above optimization, Sort2+OddEvenKeys compile to four
298 // instructions; we can save one by combining two blends.
299#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3
300 const Vec<D> cmp = VecFromMask(d, base->Compare(d, v, swapped));
301 return IfVecThenElse(DupOdd(cmp), swapped, v);
302#else
303 Sort2(d, v, swapped);
304 return base->OddEvenKeys(swapped, v);
305#endif
306 }
307
308 // (See above - we use Sort2 for non-64-bit types.)
309 template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)>
310 HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
311 const Base* base = static_cast<const Base*>(this);
312 Vec<D> swapped = base->ReverseKeys2(d, v);
313 Sort2(d, v, swapped);
314 return base->OddEvenKeys(swapped, v);
315 }
316
317 // Swaps with the vector formed by reversing contiguous groups of 4 keys.
318 template <class D>
319 HWY_INLINE Vec<D> SortPairsReverse4(D d, Vec<D> v) const {
320 const Base* base = static_cast<const Base*>(this);
321 Vec<D> swapped = base->ReverseKeys4(d, v);
322 Sort2(d, v, swapped);
323 return base->OddEvenPairs(d, swapped, v);
324 }
325
326 // Conditionally swaps lane 0 with 4, 1 with 5 etc.
327 template <class D>
328 HWY_INLINE Vec<D> SortPairsDistance4(D d, Vec<D> v) const {
329 const Base* base = static_cast<const Base*>(this);
330 Vec<D> swapped = base->SwapAdjacentQuads(d, v);
331 // Only used in Merge16, so this will not be used on AVX2 (which only has 4
332 // u64 lanes), so skip the above optimization for 64-bit AVX2.
333 Sort2(d, v, swapped);
334 return base->OddEvenQuads(d, swapped, v);
335 }
336};
337
338#else
339
340// Base class shared between OrderAscending, OrderDescending.
341template <typename T>
342struct KeyLane {
343 constexpr bool Is128() const { return false; }
344 constexpr size_t LanesPerKey() const { return 1; }
345
346 using LaneType = T;
347 using KeyType = T;
348
349 std::string KeyString() const {
350 char string100[100];
351 hwy::detail::TypeName(hwy::detail::MakeTypeInfo<KeyType>(), 1, string100);
352 return string100;
353 }
354};
355
356template <typename T>
357struct OrderAscending : public KeyLane<T> {
359
360 HWY_INLINE bool Compare1(const T* a, const T* b) { return *a < *b; }
361
362 template <class D>
364 return Lt(a, b);
365 }
366};
367
368template <typename T>
369struct OrderDescending : public KeyLane<T> {
371
372 HWY_INLINE bool Compare1(const T* a, const T* b) { return *b < *a; }
373
374 template <class D>
376 return Lt(b, a);
377 }
378};
379
380template <class Order>
381struct TraitsLane : public Order {
382 // For HeapSort
383 template <typename T> // MSVC doesn't find typename Order::LaneType.
384 HWY_INLINE void Swap(T* a, T* b) const {
385 const T temp = *a;
386 *a = *b;
387 *b = temp;
388 }
389
390 template <class D>
391 HWY_INLINE Vec<D> SetKey(D d, const TFromD<D>* key) const {
392 return Set(d, *key);
393 }
394};
395
396#endif // VQSORT_ENABLED
397
398} // namespace detail
399// NOLINTNEXTLINE(google-readability-namespace-comments)
400} // namespace HWY_NAMESPACE
401} // namespace hwy
403
404#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
#define HWY_RESTRICT
Definition: base.h:61
#define HWY_INLINE
Definition: base.h:62
HWY_API Vec128< T, N > Shuffle2301(const Vec128< T, N > a, const Vec128< T, N > b)
Definition: wasm_128-inl.h:2425
HWY_INLINE Vec128< T, N > OddEven(hwy::SizeTag< 1 >, const Vec128< T, N > a, const Vec128< T, N > b)
Definition: wasm_128-inl.h:3035
HWY_INLINE Vec128< T, 1 > MinOfLanes(hwy::SizeTag< sizeof(T)>, const Vec128< T, 1 > v)
Definition: arm_neon-inl.h:4804
HWY_INLINE Vec128< T, N > CompressNot(Vec128< T, N > v, const uint64_t mask_bits)
Definition: arm_neon-inl.h:5751
HWY_INLINE Vec128< T, N > IfThenElse(hwy::SizeTag< 1 >, Mask128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: x86_128-inl.h:673
HWY_INLINE Vec128< T, 1 > MaxOfLanes(hwy::SizeTag< sizeof(T)>, const Vec128< T, 1 > v)
Definition: arm_neon-inl.h:4809
d
Definition: rvv-inl.h:1742
HWY_API Vec128< T, N > OddEvenBlocks(Vec128< T, N >, Vec128< T, N > even)
Definition: arm_neon-inl.h:4533
HWY_API Vec128< T, N > DupOdd(Vec128< T, N > v)
Definition: arm_neon-inl.h:4498
HWY_API Vec128< T > Shuffle1032(const Vec128< T > v)
Definition: arm_neon-inl.h:4046
HWY_API auto Lt(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:6309
HWY_API Vec128< uint64_t, N > Min(const Vec128< uint64_t, N > a, const Vec128< uint64_t, N > b)
Definition: arm_neon-inl.h:2470
HWY_API Vec128< uint64_t, N > Max(const Vec128< uint64_t, N > a, const Vec128< uint64_t, N > b)
Definition: arm_neon-inl.h:2508
HWY_API Vec128< T, N > ConcatLowerUpper(const Simd< T, N, 0 > d, Vec128< T, N > hi, Vec128< T, N > lo)
Definition: arm_neon-inl.h:4380
HWY_API Vec128< T, N > IfVecThenElse(Vec128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: arm_neon-inl.h:2006
HWY_API Vec128< T, N > VecFromMask(Simd< T, N, 0 > d, const Mask128< T, N > v)
Definition: arm_neon-inl.h:2182
HWY_API Vec128< T, N > SwapAdjacentBlocks(Vec128< T, N > v)
Definition: arm_neon-inl.h:4540
HWY_API Vec128< T, N > Reverse2(Simd< T, N, 0 > d, const Vec128< T, N > v)
Definition: arm_neon-inl.h:3976
svuint16_t Set(Simd< bfloat16_t, N, kPow2 > d, bfloat16_t arg)
Definition: arm_sve-inl.h:312
HWY_API Vec128< T, N > Reverse8(Simd< T, N, 0 > d, const Vec128< T, N > v)
Definition: arm_neon-inl.h:4028
HWY_API Vec128< T, N > ConcatUpperLower(Simd< T, N, 0 > d, Vec128< T, N > hi, Vec128< T, N > lo)
Definition: arm_neon-inl.h:4406
HWY_API Vec128< T, N > BitCast(Simd< T, N, 0 > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition: arm_neon-inl.h:988
decltype(MaskFromVec(Zero(D()))) Mask
Definition: generic_ops-inl.h:38
HWY_API Vec128< T, N > Reverse4(Simd< T, N, 0 > d, const Vec128< T, N > v)
Definition: arm_neon-inl.h:4005
HWY_API Vec128< T, 1 > Reverse(Simd< T, 1, 0 >, const Vec128< T, 1 > v)
Definition: arm_neon-inl.h:3945
const vfloat64m1_t v
Definition: rvv-inl.h:1742
typename D::T TFromD
Definition: ops/shared-inl.h:191
decltype(Zero(D())) Vec
Definition: generic_ops-inl.h:32
HWY_DLLEXPORT void TypeName(const TypeInfo &info, size_t N, char *string100)
Definition: aligned_allocator.h:27
#define HWY_NAMESPACE
Definition: set_macros-inl.h:82
Definition: traits-inl.h:342
T LaneType
Definition: traits-inl.h:346
constexpr size_t LanesPerKey() const
Definition: traits-inl.h:344
T KeyType
Definition: traits-inl.h:347
constexpr bool Is128() const
Definition: traits-inl.h:343
std::string KeyString() const
Definition: traits-inl.h:349
Definition: traits-inl.h:357
HWY_INLINE bool Compare1(const T *a, const T *b)
Definition: traits-inl.h:360
HWY_INLINE Mask< D > Compare(D, Vec< D > a, Vec< D > b)
Definition: traits-inl.h:363
SortAscending Order
Definition: traits-inl.h:358
HWY_INLINE bool Compare1(const T *a, const T *b)
Definition: traits-inl.h:372
HWY_INLINE Mask< D > Compare(D, Vec< D > a, Vec< D > b)
Definition: traits-inl.h:375
SortDescending Order
Definition: traits-inl.h:370
Definition: traits-inl.h:381
HWY_INLINE void Swap(T *a, T *b) const
Definition: traits-inl.h:384
HWY_INLINE Vec< D > SetKey(D d, const TFromD< D > *key) const
Definition: traits-inl.h:391
Definition: vqsort.h:32
static constexpr size_t kMaxCols
Definition: contrib/sort/shared-inl.h:34
Definition: vqsort.h:35
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()