Grok 10.0.0
traits128-inl.h
Go to the documentation of this file.
1// Copyright 2021 Google LLC
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Per-target
17#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \
18 defined(HWY_TARGET_TOGGLE)
19#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
20#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
21#else
22#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
23#endif
24
25#include <string>
26
28#include "hwy/contrib/sort/vqsort.h" // SortDescending
29#include "hwy/highway.h"
30
32namespace hwy {
33namespace HWY_NAMESPACE {
34namespace detail {
35
36#if VQSORT_ENABLED || HWY_IDE
37
38// Highway does not provide a lane type for 128-bit keys, so we use uint64_t
39// along with an abstraction layer for single-lane vs. lane-pair, which is
40// independent of the order.
41struct KeyAny128 {
42 constexpr bool Is128() const { return true; }
43 constexpr size_t LanesPerKey() const { return 2; }
44
45 // What type bench_sort should allocate for generating inputs.
46 using LaneType = uint64_t;
47 // KeyType and KeyString are defined by derived classes.
48
49 HWY_INLINE void Swap(LaneType* a, LaneType* b) const {
50 const FixedTag<LaneType, 2> d;
51 const auto temp = LoadU(d, a);
52 StoreU(LoadU(d, b), d, a);
53 StoreU(temp, d, b);
54 }
55
56 template <class V, class M>
57 HWY_INLINE V CompressKeys(V keys, M mask) const {
58 return CompressBlocksNot(keys, mask);
59 }
60
61 template <class D>
62 HWY_INLINE Vec<D> SetKey(D d, const TFromD<D>* key) const {
63 return LoadDup128(d, key);
64 }
65
66 template <class D>
67 HWY_INLINE Vec<D> ReverseKeys(D d, Vec<D> v) const {
68 return ReverseBlocks(d, v);
69 }
70
71 template <class D>
72 HWY_INLINE Vec<D> ReverseKeys2(D /* tag */, const Vec<D> v) const {
73 return SwapAdjacentBlocks(v);
74 }
75
76 // Only called for 4 keys because we do not support >512-bit vectors.
77 template <class D>
78 HWY_INLINE Vec<D> ReverseKeys4(D d, const Vec<D> v) const {
79 HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD<D>));
80 return ReverseKeys(d, v);
81 }
82
83 // Only called for 4 keys because we do not support >512-bit vectors.
84 template <class D>
85 HWY_INLINE Vec<D> OddEvenPairs(D d, const Vec<D> odd,
86 const Vec<D> even) const {
87 HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD<D>));
88 return ConcatUpperLower(d, odd, even);
89 }
90
91 template <class V>
92 HWY_INLINE V OddEvenKeys(const V odd, const V even) const {
93 return OddEvenBlocks(odd, even);
94 }
95
96 template <class D>
97 HWY_INLINE Vec<D> ReverseKeys8(D, Vec<D>) const {
98 HWY_ASSERT(0); // not supported: would require 1024-bit vectors
99 }
100
101 template <class D>
102 HWY_INLINE Vec<D> ReverseKeys16(D, Vec<D>) const {
103 HWY_ASSERT(0); // not supported: would require 2048-bit vectors
104 }
105
106 // This is only called for 8/16 col networks (not supported).
107 template <class D>
108 HWY_INLINE Vec<D> SwapAdjacentPairs(D, Vec<D>) const {
109 HWY_ASSERT(0);
110 }
111
112 // This is only called for 16 col networks (not supported).
113 template <class D>
114 HWY_INLINE Vec<D> SwapAdjacentQuads(D, Vec<D>) const {
115 HWY_ASSERT(0);
116 }
117
118 // This is only called for 8 col networks (not supported).
119 template <class D>
120 HWY_INLINE Vec<D> OddEvenQuads(D, Vec<D>, Vec<D>) const {
121 HWY_ASSERT(0);
122 }
123};
124
125// Base class shared between OrderAscending128, OrderDescending128.
126struct Key128 : public KeyAny128 {
127 // What type to pass to Sorter::operator().
128 using KeyType = hwy::uint128_t;
129
130 std::string KeyString() const { return "U128"; }
131};
132
133// Anything order-related depends on the key traits *and* the order (see
134// FirstOfLanes). We cannot implement just one Compare function because Lt128
135// only compiles if the lane type is u64. Thus we need either overloaded
136// functions with a tag type, class specializations, or separate classes.
137// We avoid overloaded functions because we want all functions to be callable
138// from a SortTraits without per-function wrappers. Specializing would work, but
139// we are anyway going to specialize at a higher level.
140struct OrderAscending128 : public Key128 {
141 using Order = SortAscending;
142
143 HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) {
144 return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1];
145 }
146
147 template <class D>
148 HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
149 return Lt128(d, a, b);
150 }
151
152 // Used by CompareTop
153 template <class V>
154 HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
155 return Lt(a, b);
156 }
157
158 template <class D>
159 HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
160 return Min128(d, a, b);
161 }
162
163 template <class D>
164 HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
165 return Max128(d, a, b);
166 }
167
168 // Same as for regular lanes because 128-bit lanes are u64.
169 template <class D>
170 HWY_INLINE Vec<D> FirstValue(D d) const {
171 return Set(d, hwy::LowestValue<TFromD<D> >());
172 }
173
174 template <class D>
175 HWY_INLINE Vec<D> LastValue(D d) const {
176 return Set(d, hwy::HighestValue<TFromD<D> >());
177 }
178};
179
180struct OrderDescending128 : public Key128 {
181 using Order = SortDescending;
182
183 HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) {
184 return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1];
185 }
186
187 template <class D>
188 HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
189 return Lt128(d, b, a);
190 }
191
192 // Used by CompareTop
193 template <class V>
194 HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
195 return Lt(b, a);
196 }
197
198 template <class D>
199 HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
200 return Max128(d, a, b);
201 }
202
203 template <class D>
204 HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
205 return Min128(d, a, b);
206 }
207
208 // Same as for regular lanes because 128-bit lanes are u64.
209 template <class D>
210 HWY_INLINE Vec<D> FirstValue(D d) const {
211 return Set(d, hwy::HighestValue<TFromD<D> >());
212 }
213
214 template <class D>
215 HWY_INLINE Vec<D> LastValue(D d) const {
216 return Set(d, hwy::LowestValue<TFromD<D> >());
217 }
218};
219
220// Base class shared between OrderAscendingKV128, OrderDescendingKV128.
221struct KeyValue128 : public KeyAny128 {
222 // What type to pass to Sorter::operator().
223 using KeyType = K64V64;
224
225 std::string KeyString() const { return "KV128"; }
226};
227
228struct OrderAscendingKV128 : public KeyValue128 {
229 using Order = SortAscending;
230
231 HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) {
232 return a[1] < b[1];
233 }
234
235 template <class D>
236 HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
237 return Lt128Upper(d, a, b);
238 }
239
240 // Used by CompareTop
241 template <class V>
242 HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
243 return Lt(a, b);
244 }
245
246 template <class D>
247 HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
248 return Min128Upper(d, a, b);
249 }
250
251 template <class D>
252 HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
253 return Max128Upper(d, a, b);
254 }
255
256 // Same as for regular lanes because 128-bit lanes are u64.
257 template <class D>
258 HWY_INLINE Vec<D> FirstValue(D d) const {
259 return Set(d, hwy::LowestValue<TFromD<D> >());
260 }
261
262 template <class D>
263 HWY_INLINE Vec<D> LastValue(D d) const {
264 return Set(d, hwy::HighestValue<TFromD<D> >());
265 }
266};
267
268struct OrderDescendingKV128 : public KeyValue128 {
269 using Order = SortDescending;
270
271 HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) {
272 return b[1] < a[1];
273 }
274
275 template <class D>
276 HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
277 return Lt128Upper(d, b, a);
278 }
279
280 // Used by CompareTop
281 template <class V>
282 HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
283 return Lt(b, a);
284 }
285
286 template <class D>
287 HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
288 return Max128Upper(d, a, b);
289 }
290
291 template <class D>
292 HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
293 return Min128Upper(d, a, b);
294 }
295
296 // Same as for regular lanes because 128-bit lanes are u64.
297 template <class D>
298 HWY_INLINE Vec<D> FirstValue(D d) const {
299 return Set(d, hwy::HighestValue<TFromD<D> >());
300 }
301
302 template <class D>
303 HWY_INLINE Vec<D> LastValue(D d) const {
304 return Set(d, hwy::LowestValue<TFromD<D> >());
305 }
306};
307
308// Shared code that depends on Order.
309template <class Base>
310class Traits128 : public Base {
311 // Special case for >= 256 bit vectors
312#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SVE_256
313 // Returns vector with only the top u64 lane valid. Useful when the next step
314 // is to replicate the mask anyway.
315 template <class D>
316 HWY_INLINE HWY_MAYBE_UNUSED Vec<D> CompareTop(D d, Vec<D> a, Vec<D> b) const {
317 const Base* base = static_cast<const Base*>(this);
318 const Mask<D> eqHL = Eq(a, b);
319 const Vec<D> ltHL = VecFromMask(d, base->CompareLanes(a, b));
320#if HWY_TARGET == HWY_SVE_256
321 return IfThenElse(eqHL, DupEven(ltHL), ltHL);
322#else
323 const Vec<D> ltLX = ShiftLeftLanes<1>(ltHL);
324 return OrAnd(ltHL, VecFromMask(d, eqHL), ltLX);
325#endif
326 }
327
328 // We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in
329 // the most-significant of those lanes (the result of CompareTop), so
330 // replicate it 4x. Only called for >= 256-bit vectors.
331 template <class V>
332 HWY_INLINE V ReplicateTop4x(V v) const {
333#if HWY_TARGET == HWY_SVE_256
334 return svdup_lane_u64(v, 3);
335#elif HWY_TARGET <= HWY_AVX3
336 return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))};
337#else // AVX2
338 return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))};
339#endif
340 }
341#endif // HWY_TARGET
342
343 public:
344 template <class D>
345 HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
346 TFromD<D>* HWY_RESTRICT buf) const {
347 const Base* base = static_cast<const Base*>(this);
348 const size_t N = Lanes(d);
349 Store(v, d, buf);
350 v = base->SetKey(d, buf + 0); // result must be broadcasted
351 for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) {
352 v = base->First(d, v, base->SetKey(d, buf + i));
353 }
354 return v;
355 }
356
357 template <class D>
358 HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
359 TFromD<D>* HWY_RESTRICT buf) const {
360 const Base* base = static_cast<const Base*>(this);
361 const size_t N = Lanes(d);
362 Store(v, d, buf);
363 v = base->SetKey(d, buf + 0); // result must be broadcasted
364 for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) {
365 v = base->Last(d, v, base->SetKey(d, buf + i));
366 }
367 return v;
368 }
369
370 template <class D>
371 HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const {
372 const Base* base = static_cast<const Base*>(this);
373
374 const Vec<D> a_copy = a;
375 const auto lt = base->Compare(d, a, b);
376 a = IfThenElse(lt, a, b);
377 b = IfThenElse(lt, b, a_copy);
378 }
379
380 // Conditionally swaps even-numbered lanes with their odd-numbered neighbor.
381 template <class D>
382 HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
383 const Base* base = static_cast<const Base*>(this);
384 Vec<D> swapped = base->ReverseKeys2(d, v);
385
386#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SVE_256
387 const Vec<D> select = ReplicateTop4x(CompareTop(d, v, swapped));
388 return IfVecThenElse(select, swapped, v);
389#else
390 Sort2(d, v, swapped);
391 return base->OddEvenKeys(swapped, v);
392#endif
393 }
394
395 // Swaps with the vector formed by reversing contiguous groups of 4 keys.
396 template <class D>
397 HWY_INLINE Vec<D> SortPairsReverse4(D d, Vec<D> v) const {
398 const Base* base = static_cast<const Base*>(this);
399 Vec<D> swapped = base->ReverseKeys4(d, v);
400
401 // Only specialize for AVX3 because this requires 512-bit vectors.
402#if HWY_TARGET <= HWY_AVX3
403 const Vec512<uint64_t> outHx = CompareTop(d, v, swapped);
404 // Similar to ReplicateTop4x, we want to gang together 2 comparison results
405 // (4 lanes). They are not contiguous, so use permute to replicate 4x.
406 alignas(64) uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7};
407 const Vec512<uint64_t> select =
408 TableLookupLanes(outHx, SetTableIndices(d, kIndices));
409 return IfVecThenElse(select, swapped, v);
410#else
411 Sort2(d, v, swapped);
412 return base->OddEvenPairs(d, swapped, v);
413#endif
414 }
415
416 // Conditionally swaps lane 0 with 4, 1 with 5 etc.
417 template <class D>
418 HWY_INLINE Vec<D> SortPairsDistance4(D, Vec<D>) const {
419 // Only used by Merge16, which would require 2048 bit vectors (unsupported).
420 HWY_ASSERT(0);
421 }
422};
423
424#endif // VQSORT_ENABLED
425
426} // namespace detail
427// NOLINTNEXTLINE(google-readability-namespace-comments)
428} // namespace HWY_NAMESPACE
429} // namespace hwy
431
432#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
#define HWY_RESTRICT
Definition: base.h:61
#define HWY_INLINE
Definition: base.h:62
#define HWY_DASSERT(condition)
Definition: base.h:191
#define HWY_MAYBE_UNUSED
Definition: base.h:73
#define HWY_ASSERT(condition)
Definition: base.h:145
HWY_INLINE Vec128< T, N > IfThenElse(hwy::SizeTag< 1 >, Mask128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: x86_128-inl.h:673
d
Definition: rvv-inl.h:1742
HWY_API Vec128< T, N > OddEvenBlocks(Vec128< T, N >, Vec128< T, N > even)
Definition: arm_neon-inl.h:4533
HWY_API auto Lt(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:6309
HWY_API auto Eq(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:6301
HWY_API Vec128< uint64_t > CompressBlocksNot(Vec128< uint64_t > v, Mask128< uint64_t >)
Definition: arm_neon-inl.h:5815
HWY_API Vec128< T, N > IfVecThenElse(Vec128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: arm_neon-inl.h:2006
HWY_API Vec128< T, N > VecFromMask(Simd< T, N, 0 > d, const Mask128< T, N > v)
Definition: arm_neon-inl.h:2182
HWY_API Vec128< T, N > DupEven(Vec128< T, N > v)
Definition: arm_neon-inl.h:4482
HWY_API constexpr size_t Lanes(Simd< T, N, kPow2 >)
Definition: arm_sve-inl.h:236
HWY_API Vec128< T, N > TableLookupLanes(Vec128< T, N > v, Indices128< T, N > idx)
Definition: arm_neon-inl.h:3934
HWY_API void StoreU(const Vec128< uint8_t > v, Full128< uint8_t >, uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2725
HWY_INLINE VFromD< D > Min128Upper(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:6260
HWY_API Vec128< T, N > SwapAdjacentBlocks(Vec128< T, N > v)
Definition: arm_neon-inl.h:4540
HWY_INLINE VFromD< D > Min128(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:6250
svuint16_t Set(Simd< bfloat16_t, N, kPow2 > d, bfloat16_t arg)
Definition: arm_sve-inl.h:312
HWY_INLINE VFromD< D > Max128Upper(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:6265
HWY_INLINE Mask128< T, N > Lt128(Simd< T, N, 0 > d, Vec128< T, N > a, Vec128< T, N > b)
Definition: arm_neon-inl.h:6212
decltype(GetLane(V())) LaneType
Definition: generic_ops-inl.h:25
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2544
HWY_API Vec128< T, N > OrAnd(Vec128< T, N > o, Vec128< T, N > a1, Vec128< T, N > a2)
Definition: arm_neon-inl.h:1999
HWY_API Vec128< T, N > ConcatUpperLower(Simd< T, N, 0 > d, Vec128< T, N > hi, Vec128< T, N > lo)
Definition: arm_neon-inl.h:4406
HWY_INLINE VFromD< D > Max128(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:6255
HWY_API Indices128< T, N > SetTableIndices(Simd< T, N, 0 > d, const TI *idx)
Definition: arm_neon-inl.h:3928
HWY_API Vec128< T, N > LoadDup128(Simd< T, N, 0 > d, const T *const HWY_RESTRICT p)
Definition: arm_neon-inl.h:2718
N
Definition: rvv-inl.h:1742
HWY_API Vec128< T > ReverseBlocks(Full128< T >, const Vec128< T > v)
Definition: arm_neon-inl.h:4548
HWY_API void Store(Vec128< T, N > v, Simd< T, N, 0 > d, T *HWY_RESTRICT aligned)
Definition: arm_neon-inl.h:2882
HWY_INLINE Mask128< T, N > Lt128Upper(Simd< T, N, 0 > d, Vec128< T, N > a, Vec128< T, N > b)
Definition: arm_neon-inl.h:6240
const vfloat64m1_t v
Definition: rvv-inl.h:1742
Definition: aligned_allocator.h:27
HWY_API constexpr T HighestValue()
Definition: base.h:576
HWY_API constexpr T LowestValue()
Definition: base.h:563
#define HWY_NAMESPACE
Definition: set_macros-inl.h:82
Definition: base.h:264
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()