Grok  9.7.5
traits128-inl.h
Go to the documentation of this file.
1 // Copyright 2021 Google LLC
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 // Per-target
17 #if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \
18  defined(HWY_TARGET_TOGGLE)
19 #ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
20 #undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
21 #else
22 #define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
23 #endif
24 
25 #include "hwy/contrib/sort/vqsort.h" // SortDescending
26 #include "hwy/highway.h"
27 
29 namespace hwy {
30 namespace HWY_NAMESPACE {
31 namespace detail {
32 
33 #if HWY_TARGET == HWY_SCALAR
34 
37 
38  template <typename T>
39  HWY_INLINE bool Compare1(const T* a, const T* b) {
40  return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1];
41  }
42 };
43 
46 
47  template <typename T>
48  HWY_INLINE bool Compare1(const T* a, const T* b) {
49  return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1];
50  }
51 };
52 
// Scalar-target traits for 128-bit keys. Everything order-specific is
// inherited from `Order`; this wrapper only adds the two queries below.
template <class Order>
struct Traits128 : public Order {
  // Keys handled through these traits are 128-bit.
  constexpr bool Is128() const {
    return true;
  }

  // One key spans two lanes.
  constexpr size_t LanesPerKey() const {
    return 2;
  }
};
58 
59 #else
60 
61 // Highway does not provide a lane type for 128-bit keys, so we use uint64_t
62 // along with an abstraction layer for single-lane vs. lane-pair, which is
63 // independent of the order.
64 struct Key128 {
65  constexpr size_t LanesPerKey() const { return 2; }
66 
67  template <typename T>
68  HWY_INLINE void Swap(T* a, T* b) const {
69  const FixedTag<T, 2> d;
70  const auto temp = LoadU(d, a);
71  StoreU(LoadU(d, b), d, a);
72  StoreU(temp, d, b);
73  }
74 
75  template <class D>
76  HWY_INLINE Vec<D> SetKey(D d, const TFromD<D>* key) const {
77  return LoadDup128(d, key);
78  }
79 
80  template <class D>
81  HWY_INLINE Vec<D> ReverseKeys(D d, Vec<D> v) const {
82  return ReverseBlocks(d, v);
83  }
84 
85  template <class D>
86  HWY_INLINE Vec<D> ReverseKeys2(D /* tag */, const Vec<D> v) const {
87  return SwapAdjacentBlocks(v);
88  }
89 
90  // Only called for 4 keys because we do not support >512-bit vectors.
91  template <class D>
92  HWY_INLINE Vec<D> ReverseKeys4(D d, const Vec<D> v) const {
93  HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD<D>));
94  return ReverseKeys(d, v);
95  }
96 
97  // Only called for 4 keys because we do not support >512-bit vectors.
98  template <class D>
99  HWY_INLINE Vec<D> OddEvenPairs(D d, const Vec<D> odd,
100  const Vec<D> even) const {
101  HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD<D>));
102  return ConcatUpperLower(d, odd, even);
103  }
104 
105  template <class V>
106  HWY_INLINE V OddEvenKeys(const V odd, const V even) const {
107  return OddEvenBlocks(odd, even);
108  }
109 
110  template <class D>
111  HWY_INLINE Vec<D> ReverseKeys8(D, Vec<D>) const {
112  HWY_ASSERT(0); // not supported: would require 1024-bit vectors
113  }
114 
115  template <class D>
116  HWY_INLINE Vec<D> ReverseKeys16(D, Vec<D>) const {
117  HWY_ASSERT(0); // not supported: would require 2048-bit vectors
118  }
119 
120  // This is only called for 8/16 col networks (not supported).
121  template <class D>
122  HWY_INLINE Vec<D> SwapAdjacentPairs(D, Vec<D>) const {
123  HWY_ASSERT(0);
124  }
125 
126  // This is only called for 16 col networks (not supported).
127  template <class D>
128  HWY_INLINE Vec<D> SwapAdjacentQuads(D, Vec<D>) const {
129  HWY_ASSERT(0);
130  }
131 
132  // This is only called for 8 col networks (not supported).
133  template <class D>
134  HWY_INLINE Vec<D> OddEvenQuads(D, Vec<D>, Vec<D>) const {
135  HWY_ASSERT(0);
136  }
137 };
138 
139 // Anything order-related depends on the key traits *and* the order (see
140 // FirstOfLanes). We cannot implement just one Compare function because Lt128
141 // only compiles if the lane type is u64. Thus we need either overloaded
142 // functions with a tag type, class specializations, or separate classes.
143 // We avoid overloaded functions because we want all functions to be callable
144 // from a SortTraits without per-function wrappers. Specializing would work, but
145 // we are anyway going to specialize at a higher level.
146 struct OrderAscending128 : public Key128 {
147  using Order = SortAscending;
148 
149  template <typename T>
150  HWY_INLINE bool Compare1(const T* a, const T* b) {
151  return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1];
152  }
153 
154  template <class D>
155  HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
156  return Lt128(d, a, b);
157  }
158 
159  // Used by CompareTop
160  template <class V>
161  HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
162  return Lt(a, b);
163  }
164 
165  template <class D>
166  HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
167  return Min128(d, a, b);
168  }
169 
170  template <class D>
171  HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
172  return Max128(d, a, b);
173  }
174 
175  template <class D>
176  HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
177  TFromD<D>* HWY_RESTRICT buf) const {
178  const size_t N = Lanes(d);
179  Store(v, d, buf);
180  v = SetKey(d, buf + 0); // result must be broadcasted
181  for (size_t i = LanesPerKey(); i < N; i += LanesPerKey()) {
182  v = First(d, v, SetKey(d, buf + i));
183  }
184  return v;
185  }
186 
187  template <class D>
188  HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
189  TFromD<D>* HWY_RESTRICT buf) const {
190  const size_t N = Lanes(d);
191  Store(v, d, buf);
192  v = SetKey(d, buf + 0); // result must be broadcasted
193  for (size_t i = LanesPerKey(); i < N; i += LanesPerKey()) {
194  v = Last(d, v, SetKey(d, buf + i));
195  }
196  return v;
197  }
198 
199  // Same as for regular lanes because 128-bit lanes are u64.
200  template <class D>
201  HWY_INLINE Vec<D> FirstValue(D d) const {
202  return Set(d, hwy::LowestValue<TFromD<D> >());
203  }
204 
205  template <class D>
206  HWY_INLINE Vec<D> LastValue(D d) const {
207  return Set(d, hwy::HighestValue<TFromD<D> >());
208  }
209 };
210 
211 struct OrderDescending128 : public Key128 {
212  using Order = SortDescending;
213 
214  template <typename T>
215  HWY_INLINE bool Compare1(const T* a, const T* b) {
216  return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1];
217  }
218 
219  template <class D>
220  HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
221  return Lt128(d, b, a);
222  }
223 
224  // Used by CompareTop
225  template <class V>
226  HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
227  return Lt(b, a);
228  }
229 
230  template <class D>
231  HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
232  return Max128(d, a, b);
233  }
234 
235  template <class D>
236  HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
237  return Min128(d, a, b);
238  }
239 
240  template <class D>
241  HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
242  TFromD<D>* HWY_RESTRICT buf) const {
243  const size_t N = Lanes(d);
244  Store(v, d, buf);
245  v = SetKey(d, buf + 0); // result must be broadcasted
246  for (size_t i = LanesPerKey(); i < N; i += LanesPerKey()) {
247  v = First(d, v, SetKey(d, buf + i));
248  }
249  return v;
250  }
251 
252  template <class D>
253  HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
254  TFromD<D>* HWY_RESTRICT buf) const {
255  const size_t N = Lanes(d);
256  Store(v, d, buf);
257  v = SetKey(d, buf + 0); // result must be broadcasted
258  for (size_t i = LanesPerKey(); i < N; i += LanesPerKey()) {
259  v = Last(d, v, SetKey(d, buf + i));
260  }
261  return v;
262  }
263 
264  // Same as for regular lanes because 128-bit lanes are u64.
265  template <class D>
266  HWY_INLINE Vec<D> FirstValue(D d) const {
267  return Set(d, hwy::HighestValue<TFromD<D> >());
268  }
269 
270  template <class D>
271  HWY_INLINE Vec<D> LastValue(D d) const {
272  return Set(d, hwy::LowestValue<TFromD<D> >());
273  }
274 };
275 
276 // Shared code that depends on Order.
277 template <class Base>
278 class Traits128 : public Base {
279 #if HWY_TARGET <= HWY_AVX2
280  // Returns vector with only the top u64 lane valid. Useful when the next step
281  // is to replicate the mask anyway.
282  template <class D>
283  HWY_INLINE HWY_MAYBE_UNUSED Vec<D> CompareTop(D d, Vec<D> a, Vec<D> b) const {
284  const Base* base = static_cast<const Base*>(this);
285  const Vec<D> eqHL = VecFromMask(d, Eq(a, b));
286  const Vec<D> ltHL = VecFromMask(d, base->CompareLanes(a, b));
287  const Vec<D> ltLX = ShiftLeftLanes<1>(ltHL);
288  return OrAnd(ltHL, eqHL, ltLX);
289  }
290 
291  // We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in
292  // the most-significant of those lanes (the result of CompareTop), so
293  // replicate it 4x. Only called for >= 256-bit vectors.
294  template <class V>
295  HWY_INLINE V ReplicateTop4x(V v) const {
296 #if HWY_TARGET <= HWY_AVX3
297  return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))};
298 #else // AVX2
299  return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))};
300 #endif
301  }
302 #endif
303 
304  public:
305  constexpr bool Is128() const { return true; }
306 
307  template <class D>
308  HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const {
309  const Base* base = static_cast<const Base*>(this);
310 
311  const Vec<D> a_copy = a;
312  const auto lt = base->Compare(d, a, b);
313  a = IfThenElse(lt, a, b);
314  b = IfThenElse(lt, b, a_copy);
315  }
316 
317  // Conditionally swaps even-numbered lanes with their odd-numbered neighbor.
318  template <class D>
319  HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
320  const Base* base = static_cast<const Base*>(this);
321  Vec<D> swapped = base->ReverseKeys2(d, v);
322 
323 #if HWY_TARGET <= HWY_AVX2
324  const Vec<D> select = ReplicateTop4x(CompareTop(d, v, swapped));
325  return IfVecThenElse(select, swapped, v);
326 #else
327  Sort2(d, v, swapped);
328  return base->OddEvenKeys(swapped, v);
329 #endif
330  }
331 
332  // Swaps with the vector formed by reversing contiguous groups of 4 keys.
333  template <class D>
334  HWY_INLINE Vec<D> SortPairsReverse4(D d, Vec<D> v) const {
335  const Base* base = static_cast<const Base*>(this);
336  Vec<D> swapped = base->ReverseKeys4(d, v);
337 
338  // Only specialize for AVX3 because this requires 512-bit vectors.
339 #if HWY_TARGET <= HWY_AVX3
340  const Vec512<uint64_t> outHx = CompareTop(d, v, swapped);
341  // Similar to ReplicateTop4x, we want to gang together 2 comparison results
342  // (4 lanes). They are not contiguous, so use permute to replicate 4x.
343  alignas(64) uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7};
344  const Vec512<uint64_t> select =
345  TableLookupLanes(outHx, SetTableIndices(d, kIndices));
346  return IfVecThenElse(select, swapped, v);
347 #else
348  Sort2(d, v, swapped);
349  return base->OddEvenPairs(d, swapped, v);
350 #endif
351  }
352 
353  // Conditionally swaps lane 0 with 4, 1 with 5 etc.
354  template <class D>
355  HWY_INLINE Vec<D> SortPairsDistance4(D, Vec<D>) const {
356  // Only used by Merge16, which would require 2048 bit vectors (unsupported).
357  HWY_ASSERT(0);
358  }
359 };
360 
361 #endif // HWY_TARGET != HWY_SCALAR
362 
363 } // namespace detail
364 // NOLINTNEXTLINE(google-readability-namespace-comments)
365 } // namespace HWY_NAMESPACE
366 } // namespace hwy
368 
369 #endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
#define HWY_RESTRICT
Definition: base.h:63
#define HWY_INLINE
Definition: base.h:64
#define HWY_DASSERT(condition)
Definition: base.h:193
#define HWY_MAYBE_UNUSED
Definition: base.h:75
#define HWY_ASSERT(condition)
Definition: base.h:147
void Swap(T *a, T *b)
Definition: vqsort-inl.h:63
HWY_INLINE Vec128< T, N > IfThenElse(hwy::SizeTag< 1 >, Mask128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: x86_128-inl.h:680
d
Definition: rvv-inl.h:1656
HWY_API Vec128< T, N > OddEvenBlocks(Vec128< T, N >, Vec128< T, N > even)
Definition: arm_neon-inl.h:4038
HWY_API auto Lt(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:5252
HWY_API auto Eq(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:5244
HWY_API size_t Lanes(Simd< T, N, kPow2 > d)
Definition: arm_sve-inl.h:218
HWY_API Vec128< T, N > IfVecThenElse(Vec128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: arm_neon-inl.h:1505
HWY_API Vec128< T, N > VecFromMask(Simd< T, N, 0 > d, const Mask128< T, N > v)
Definition: arm_neon-inl.h:1681
HWY_API Vec128< T, N > TableLookupLanes(Vec128< T, N > v, Indices128< T, N > idx)
Definition: arm_neon-inl.h:3419
HWY_API void StoreU(const Vec128< uint8_t > v, Full128< uint8_t >, uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2224
HWY_API Vec128< T, N > SwapAdjacentBlocks(Vec128< T, N > v)
Definition: arm_neon-inl.h:4045
HWY_INLINE VFromD< D > Min128(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:5203
svuint16_t Set(Simd< bfloat16_t, N, kPow2 > d, bfloat16_t arg)
Definition: arm_sve-inl.h:282
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2031
HWY_INLINE Mask128< T, N > Lt128(Simd< T, N, 0 > d, Vec128< T, N > a, Vec128< T, N > b)
Definition: arm_neon-inl.h:5172
HWY_API Vec128< T, N > OrAnd(Vec128< T, N > o, Vec128< T, N > a1, Vec128< T, N > a2)
Definition: arm_neon-inl.h:1498
HWY_API Vec128< T, N > ConcatUpperLower(Simd< T, N, 0 > d, Vec128< T, N > hi, Vec128< T, N > lo)
Definition: arm_neon-inl.h:3895
HWY_INLINE VFromD< D > Max128(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:5208
HWY_API Indices128< T, N > SetTableIndices(Simd< T, N, 0 > d, const TI *idx)
Definition: arm_neon-inl.h:3413
HWY_API Vec128< T, N > LoadDup128(Simd< T, N, 0 > d, const T *const HWY_RESTRICT p)
Definition: arm_neon-inl.h:2217
N
Definition: rvv-inl.h:1656
HWY_API Vec128< T > ReverseBlocks(Full128< T >, const Vec128< T > v)
Definition: arm_neon-inl.h:4053
HWY_API void Store(Vec128< T, N > v, Simd< T, N, 0 > d, T *HWY_RESTRICT aligned)
Definition: arm_neon-inl.h:2397
const vfloat64m1_t v
Definition: rvv-inl.h:1656
Definition: aligned_allocator.h:27
constexpr HWY_API T LowestValue()
Definition: base.h:512
constexpr HWY_API T HighestValue()
Definition: base.h:525
#define HWY_NAMESPACE
Definition: set_macros-inl.h:80
Definition: traits128-inl.h:35
SortAscending Order
Definition: traits128-inl.h:36
HWY_INLINE bool Compare1(const T *a, const T *b)
Definition: traits128-inl.h:39
SortDescending Order
Definition: traits128-inl.h:45
HWY_INLINE bool Compare1(const T *a, const T *b)
Definition: traits128-inl.h:48
Definition: traits128-inl.h:54
constexpr bool Is128() const
Definition: traits128-inl.h:55
constexpr size_t LanesPerKey() const
Definition: traits128-inl.h:56
Definition: vqsort.h:35
Definition: vqsort.h:38
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()