Grok  9.5.0
generic_ops-inl.h
Go to the documentation of this file.
1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Target-independent types/functions defined after target-specific ops.
16 
17 // Relies on the external include guard in highway.h.
19 namespace hwy {
20 namespace HWY_NAMESPACE {
21 
// The lane type of a vector type, e.g. float for Vec<Simd<float, 4>>.
// Deduced from the return type of GetLane applied to a default-constructed V.
template <class V>
using LaneType = decltype(GetLane(V()));
25 
// Vector type, e.g. Vec128<float> for Simd<float, 4>. Useful as the return type
// of functions that do not take a vector argument, or as an argument type if
// the function only has a template argument for D, or for explicit type names
// instead of auto. This may be a built-in type.
// Deduced from the return type of Zero for a default-constructed descriptor D.
template <class D>
using Vec = decltype(Zero(D()));
32 
// Mask type. Useful as the return type of functions that do not take a mask
// argument, or as an argument type if the function only has a template argument
// for D, or for explicit type names instead of auto.
// Deduced from the return type of MaskFromVec on a zero vector of D.
template <class D>
using Mask = decltype(MaskFromVec(Zero(D())));
38 
39 // Returns the closest value to v within [lo, hi].
40 template <class V>
41 HWY_API V Clamp(const V v, const V lo, const V hi) {
42  return Min(Max(lo, v), hi);
43 }
44 
45 // CombineShiftRightBytes (and -Lanes) are not available for the scalar target,
46 // and RVV has its own implementation of -Lanes.
47 #if HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV
48 
// Lane-granularity variant of CombineShiftRightBytes: shifts the concatenation
// of hi:lo right by kLanes lanes (per 128-bit block) and returns the lower part.
template <size_t kLanes, class D, class V = VFromD<D>>
HWY_API V CombineShiftRightLanes(D d, const V hi, const V lo) {
  // Convert the lane count into bytes; CombineShiftRightBytes operates on
  // bytes within each 128-bit block.
  constexpr size_t kBytes = kLanes * sizeof(LaneType<V>);
  static_assert(kBytes < 16, "Shift count is per-block");
  return CombineShiftRightBytes<kBytes>(d, hi, lo);
}
55 
56 // DEPRECATED
57 template <size_t kLanes, class V>
58 HWY_API V CombineShiftRightLanes(const V hi, const V lo) {
59  return CombineShiftRightLanes<kLanes>(DFromV<V>(), hi, lo);
60 }
61 
62 #endif
63 
64 // Returns lanes with the most significant bit set and all other bits zero.
65 template <class D>
67  using Unsigned = MakeUnsigned<TFromD<D>>;
68  const Unsigned bit = Unsigned(1) << (sizeof(Unsigned) * 8 - 1);
69  return BitCast(d, Set(Rebind<Unsigned, D>(), bit));
70 }
71 
72 // Returns quiet NaN.
73 template <class D>
75  const RebindToSigned<D> di;
76  // LimitsMax sets all exponent and mantissa bits to 1. The exponent plus
77  // mantissa MSB (to indicate quiet) would be sufficient.
78  return BitCast(d, Set(di, LimitsMax<TFromD<decltype(di)>>()));
79 }
80 
81 // ------------------------------ AESRound
82 
83 // Cannot implement on scalar: need at least 16 bytes for TableLookupBytes.
84 #if HWY_TARGET != HWY_SCALAR
85 
86 // Define for white-box testing, even if native instructions are available.
87 namespace detail {
88 
89 // Constant-time: computes inverse in GF(2^4) based on "Accelerating AES with
90 // Vector Permute Instructions" and the accompanying assembly language
91 // implementation: https://crypto.stanford.edu/vpaes/vpaes.tgz. See also Botan:
92 // https://botan.randombit.net/doxygen/aes__vperm_8cpp_source.html .
93 //
94 // A brute-force 256 byte table lookup can also be made constant-time, and
95 // possibly competitive on NEON, but this is more performance-portable
96 // especially for x86 and large vectors.
// Constant-time AES SubBytes for u8 vectors. Decomposes the S-box into a
// change of basis to GF(2^4), an inversion there, and an affine output skew,
// each realized as 16-entry nibble table lookups (see the paper cited above).
template <class V>  // u8
HWY_INLINE V SubBytes(V state) {
  const DFromV<V> du;
  const auto mask = Set(du, 0xF);  // selects the low nibble of each byte

  // Change polynomial basis to GF(2^4)
  {
    // Per-nibble basis-change tables for the low and high nibbles.
    alignas(16) static constexpr uint8_t basisL[16] = {
        0x00, 0x70, 0x2A, 0x5A, 0x98, 0xE8, 0xB2, 0xC2,
        0x08, 0x78, 0x22, 0x52, 0x90, 0xE0, 0xBA, 0xCA};
    alignas(16) static constexpr uint8_t basisU[16] = {
        0x00, 0x4D, 0x7C, 0x31, 0x7D, 0x30, 0x01, 0x4C,
        0x81, 0xCC, 0xFD, 0xB0, 0xFC, 0xB1, 0x80, 0xCD};
    const auto sL = And(state, mask);
    const auto sU = ShiftRight<4>(state);  // byte shift => upper bits are zero
    const auto gf4L = TableLookupBytes(LoadDup128(du, basisL), sL);
    const auto gf4U = TableLookupBytes(LoadDup128(du, basisU), sU);
    state = Xor(gf4L, gf4U);  // combine the two half-byte contributions
  }

  // Inversion in GF(2^4). Elements 0 represent "infinity" (division by 0) and
  // cause TableLookupBytesOr0 to return 0.
  alignas(16) static constexpr uint8_t kZetaInv[16] = {
      0x80, 7, 11, 15, 6, 10, 4, 1, 9, 8, 5, 2, 12, 14, 13, 3};
  alignas(16) static constexpr uint8_t kInv[16] = {
      0x80, 1, 8, 13, 15, 6, 5, 14, 2, 12, 11, 10, 9, 3, 7, 4};
  const auto tbl = LoadDup128(du, kInv);
  const auto sL = And(state, mask);  // L=low nibble, U=upper
  const auto sU = ShiftRight<4>(state);  // byte shift => upper bits are zero
  const auto sX = Xor(sU, sL);
  const auto invL = TableLookupBytes(LoadDup128(du, kZetaInv), sL);
  const auto invU = TableLookupBytes(tbl, sU);
  const auto invX = TableLookupBytes(tbl, sX);
  // The 0x80 "infinity" flag in the tables propagates through these indices so
  // the Or0 lookups below yield 0 for them.
  const auto outL = Xor(sX, TableLookupBytesOr0(tbl, Xor(invL, invU)));
  const auto outU = Xor(sU, TableLookupBytesOr0(tbl, Xor(invL, invX)));

  // Linear skew (cannot bake 0x63 bias into the table because out* indices
  // may have the infinity flag set).
  alignas(16) static constexpr uint8_t kAffineL[16] = {
      0x00, 0xC7, 0xBD, 0x6F, 0x17, 0x6D, 0xD2, 0xD0,
      0x78, 0xA8, 0x02, 0xC5, 0x7A, 0xBF, 0xAA, 0x15};
  alignas(16) static constexpr uint8_t kAffineU[16] = {
      0x00, 0x6A, 0xBB, 0x5F, 0xA5, 0x74, 0xE4, 0xCF,
      0xFA, 0x35, 0x2B, 0x41, 0xD1, 0x90, 0x1E, 0x8E};
  const auto affL = TableLookupBytesOr0(LoadDup128(du, kAffineL), outL);
  const auto affU = TableLookupBytesOr0(LoadDup128(du, kAffineU), outU);
  // 0x63 is the additive constant of the AES affine transform.
  return Xor(Xor(affL, affU), Set(du, 0x63));
}
145 
146 } // namespace detail
147 
148 #endif // HWY_TARGET != HWY_SCALAR
149 
150 // "Include guard": skip if native AES instructions are available.
151 #if (defined(HWY_NATIVE_AES) == defined(HWY_TARGET_TOGGLE))
152 #ifdef HWY_NATIVE_AES
153 #undef HWY_NATIVE_AES
154 #else
155 #define HWY_NATIVE_AES
156 #endif
157 
158 // (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar)
159 #if HWY_TARGET != HWY_SCALAR
160 
161 namespace detail {
162 
163 template <class V> // u8
164 HWY_API V ShiftRows(const V state) {
165  const DFromV<V> du;
166  alignas(16) static constexpr uint8_t kShiftRow[16] = {
167  0, 5, 10, 15, // transposed: state is column major
168  4, 9, 14, 3, //
169  8, 13, 2, 7, //
170  12, 1, 6, 11};
171  const auto shift_row = LoadDup128(du, kShiftRow);
172  return TableLookupBytes(state, shift_row);
173 }
174 
// AES MixColumns: multiplies each state column by the fixed MDS matrix below
// over GF(2^8), expressed as XORs of byte-permuted copies of the state.
template <class V>  // u8
HWY_API V MixColumns(const V state) {
  const DFromV<V> du;
  // For each column, the rows are the sum of GF(2^8) matrix multiplication by:
  //   2 3 1 1  // Let s := state*1, d := state*2, t := state*3.
  //   1 2 3 1  // d are on diagonal, no permutation needed.
  //   1 1 2 3  // t1230 indicates column indices of threes for the 4 rows.
  //   3 1 1 2  // We also need to compute s2301 and s3012 (=1230 o 2301).
  alignas(16) static constexpr uint8_t k2301[16] = {
      2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13};
  alignas(16) static constexpr uint8_t k1230[16] = {
      1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12};
  const RebindToSigned<decltype(du)> di;  // can only do signed comparisons
  // Bytes with the MSB set overflow when doubled; conditionally XOR the AES
  // reduction constant 0x1B for those lanes (constant-time "xtime").
  const auto msb = Lt(BitCast(di, state), Zero(di));
  const auto overflow = BitCast(du, IfThenElseZero(msb, Set(di, 0x1B)));
  const auto d = Xor(Add(state, state), overflow);  // = state*2 in GF(2^8).
  const auto s2301 = TableLookupBytes(state, LoadDup128(du, k2301));
  const auto d_s2301 = Xor(d, s2301);
  const auto t_s2301 = Xor(state, d_s2301);  // t(s*3) = XOR-sum {s, d(s*2)}
  const auto t1230_s3012 = TableLookupBytes(t_s2301, LoadDup128(du, k1230));
  return Xor(d_s2301, t1230_s3012);  // XOR-sum of 4 terms
}
197 
198 } // namespace detail
199 
200 template <class V> // u8
201 HWY_API V AESRound(V state, const V round_key) {
202  // Intel docs swap the first two steps, but it does not matter because
203  // ShiftRows is a permutation and SubBytes is independent of lane index.
204  state = detail::SubBytes(state);
205  state = detail::ShiftRows(state);
206  state = detail::MixColumns(state);
207  state = Xor(state, round_key); // AddRoundKey
208  return state;
209 }
210 
211 // Constant-time implementation inspired by
212 // https://www.bearssl.org/constanttime.html, but about half the cost because we
213 // use 64x64 multiplies and 128-bit XORs.
// Carry-less (polynomial) multiply of the even-indexed u64 lanes of a and b.
// Bit-slices each operand into four streams (every 4th bit), so the partial
// products of an integer multiply cannot collide within a stream; masking
// afterwards recovers the XOR (carry-free) sum. See the link above.
template <class V>
HWY_API V CLMulLower(V a, V b) {
  const DFromV<V> d;
  static_assert(IsSame<TFromD<decltype(d)>, uint64_t>(), "V must be u64");
  // Masks selecting bit positions congruent to 0, 1, 2, 3 (mod 4).
  const auto k1 = Set(d, 0x1111111111111111ULL);
  const auto k2 = Set(d, 0x2222222222222222ULL);
  const auto k4 = Set(d, 0x4444444444444444ULL);
  const auto k8 = Set(d, 0x8888888888888888ULL);
  const auto a0 = And(a, k1);
  const auto a1 = And(a, k2);
  const auto a2 = And(a, k4);
  const auto a3 = And(a, k8);
  const auto b0 = And(b, k1);
  const auto b1 = And(b, k2);
  const auto b2 = And(b, k4);
  const auto b3 = And(b, k8);

  // m_k accumulates the products a_i * b_j with (i + j) % 4 == k; MulEven
  // multiplies the even-indexed 64-bit lanes.
  auto m0 = Xor(MulEven(a0, b0), MulEven(a1, b3));
  auto m1 = Xor(MulEven(a0, b1), MulEven(a1, b0));
  auto m2 = Xor(MulEven(a0, b2), MulEven(a1, b1));
  auto m3 = Xor(MulEven(a0, b3), MulEven(a1, b2));
  m0 = Xor(m0, Xor(MulEven(a2, b2), MulEven(a3, b1)));
  m1 = Xor(m1, Xor(MulEven(a2, b3), MulEven(a3, b2)));
  m2 = Xor(m2, Xor(MulEven(a2, b0), MulEven(a3, b3)));
  m3 = Xor(m3, Xor(MulEven(a2, b1), MulEven(a3, b0)));
  // Keep only the valid (carry-free) bit positions of each accumulator.
  return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8)));
}
241 
// Carry-less (polynomial) multiply of the odd-indexed u64 lanes of a and b.
// Same bit-slicing scheme as CLMulLower, but using MulOdd.
template <class V>
HWY_API V CLMulUpper(V a, V b) {
  const DFromV<V> d;
  static_assert(IsSame<TFromD<decltype(d)>, uint64_t>(), "V must be u64");
  // Masks selecting bit positions congruent to 0, 1, 2, 3 (mod 4).
  const auto k1 = Set(d, 0x1111111111111111ULL);
  const auto k2 = Set(d, 0x2222222222222222ULL);
  const auto k4 = Set(d, 0x4444444444444444ULL);
  const auto k8 = Set(d, 0x8888888888888888ULL);
  const auto a0 = And(a, k1);
  const auto a1 = And(a, k2);
  const auto a2 = And(a, k4);
  const auto a3 = And(a, k8);
  const auto b0 = And(b, k1);
  const auto b1 = And(b, k2);
  const auto b2 = And(b, k4);
  const auto b3 = And(b, k8);

  // m_k accumulates the products a_i * b_j with (i + j) % 4 == k; MulOdd
  // multiplies the odd-indexed 64-bit lanes.
  auto m0 = Xor(MulOdd(a0, b0), MulOdd(a1, b3));
  auto m1 = Xor(MulOdd(a0, b1), MulOdd(a1, b0));
  auto m2 = Xor(MulOdd(a0, b2), MulOdd(a1, b1));
  auto m3 = Xor(MulOdd(a0, b3), MulOdd(a1, b2));
  m0 = Xor(m0, Xor(MulOdd(a2, b2), MulOdd(a3, b1)));
  m1 = Xor(m1, Xor(MulOdd(a2, b3), MulOdd(a3, b2)));
  m2 = Xor(m2, Xor(MulOdd(a2, b0), MulOdd(a3, b3)));
  m3 = Xor(m3, Xor(MulOdd(a2, b1), MulOdd(a3, b0)));
  // Keep only the valid (carry-free) bit positions of each accumulator.
  return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8)));
}
269 
270 #endif // HWY_NATIVE_AES
271 #endif // HWY_TARGET != HWY_SCALAR
272 
273 // "Include guard": skip if native POPCNT-related instructions are available.
274 #if (defined(HWY_NATIVE_POPCNT) == defined(HWY_TARGET_TOGGLE))
275 #ifdef HWY_NATIVE_POPCNT
276 #undef HWY_NATIVE_POPCNT
277 #else
278 #define HWY_NATIVE_POPCNT
279 #endif
280 
281 template <typename V, HWY_IF_LANES_ARE(uint8_t, V)>
283  constexpr DFromV<V> d;
284  HWY_ALIGN constexpr uint8_t kLookup[16] = {
285  0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
286  };
287  auto lo = And(v, Set(d, 0xF));
288  auto hi = ShiftRight<4>(v);
289  auto lookup = LoadDup128(Simd<uint8_t, HWY_MAX(16, MaxLanes(d))>(), kLookup);
290  return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo));
291 }
292 
293 template <typename V, HWY_IF_LANES_ARE(uint16_t, V)>
294 HWY_API V PopulationCount(V v) {
295  const DFromV<V> d;
296  Repartition<uint8_t, decltype(d)> d8;
297  auto vals = BitCast(d, PopulationCount(BitCast(d8, v)));
298  return Add(ShiftRight<8>(vals), And(vals, Set(d, 0xFF)));
299 }
300 
301 template <typename V, HWY_IF_LANES_ARE(uint32_t, V)>
302 HWY_API V PopulationCount(V v) {
303  const DFromV<V> d;
304  Repartition<uint16_t, decltype(d)> d16;
305  auto vals = BitCast(d, PopulationCount(BitCast(d16, v)));
306  return Add(ShiftRight<16>(vals), And(vals, Set(d, 0xFF)));
307 }
308 
309 #if HWY_CAP_INTEGER64
310 template <typename V, HWY_IF_LANES_ARE(uint64_t, V)>
311 HWY_API V PopulationCount(V v) {
312  const DFromV<V> d;
313  Repartition<uint32_t, decltype(d)> d32;
314  auto vals = BitCast(d, PopulationCount(BitCast(d32, v)));
315  return Add(ShiftRight<32>(vals), And(vals, Set(d, 0xFF)));
316 }
317 #endif
318 
319 #endif // HWY_NATIVE_POPCNT
320 
321 // NOLINTNEXTLINE(google-readability-namespace-comments)
322 } // namespace HWY_NAMESPACE
323 } // namespace hwy
#define HWY_MAX(a, b)
Definition: base.h:123
#define HWY_API
Definition: base.h:117
#define HWY_INLINE
Definition: base.h:59
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()
HWY_INLINE Mask128< T, N > Xor(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:879
HWY_INLINE Vec128< T, N > IfThenElseZero(hwy::SizeTag< 1 >, Mask128< T, N > mask, Vec128< T, N > yes)
Definition: x86_128-inl.h:672
HWY_INLINE Mask128< T, N > And(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:768
HWY_API Vec< D > SignBit(D d)
Definition: generic_ops-inl.h:66
svuint16_t Set(Simd< bfloat16_t, N > d, bfloat16_t arg)
Definition: arm_sve-inl.h:299
HWY_API uint8_t GetLane(const Vec128< uint8_t, 16 > v)
Definition: arm_neon-inl.h:744
HWY_API Vec128< T, N > PopulationCount(Vec128< T, N > v)
Definition: arm_neon-inl.h:1520
HWY_API auto Lt(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:5035
HWY_API Vec< D > NaN(D d)
Definition: generic_ops-inl.h:74
HWY_API Vec128< T, N > LoadDup128(Simd< T, N > d, const T *const HWY_RESTRICT p)
Definition: arm_neon-inl.h:2164
HWY_API Vec128< uint64_t, N > Min(const Vec128< uint64_t, N > a, const Vec128< uint64_t, N > b)
Definition: arm_neon-inl.h:1879
HWY_API Vec256< uint64_t > CLMulLower(Vec256< uint64_t > a, Vec256< uint64_t > b)
Definition: x86_256-inl.h:3495
HWY_API Vec128< uint64_t, N > Max(const Vec128< uint64_t, N > a, const Vec128< uint64_t, N > b)
Definition: arm_neon-inl.h:1917
HWY_API Mask128< T, N > MaskFromVec(const Vec128< T, N > v)
Definition: arm_neon-inl.h:1600
HWY_INLINE Vec128< uint64_t > MulOdd(Vec128< uint64_t > a, Vec128< uint64_t > b)
Definition: arm_neon-inl.h:3947
HWY_API Vec256< uint8_t > AESRound(Vec256< uint8_t > state, Vec256< uint8_t > round_key)
Definition: x86_256-inl.h:3483
HWY_API Vec128< int64_t > MulEven(Vec128< int32_t > a, Vec128< int32_t > b)
Definition: arm_neon-inl.h:3907
Rebind< MakeSigned< TFromD< D > >, D > RebindToSigned
Definition: shared-inl.h:147
HWY_API V Add(V a, V b)
Definition: arm_neon-inl.h:5000
HWY_API Vec256< uint64_t > CLMulUpper(Vec256< uint64_t > a, Vec256< uint64_t > b)
Definition: x86_256-inl.h:3506
decltype(GetLane(V())) LaneType
Definition: generic_ops-inl.h:24
HWY_API Vec128< T, N > And(const Vec128< T, N > a, const Vec128< T, N > b)
Definition: arm_neon-inl.h:1384
HWY_API Vec128< T, N > BitCast(Simd< T, N > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition: arm_neon-inl.h:687
typename D::template Rebind< T > Rebind
Definition: shared-inl.h:144
HWY_API Vec128< T, N > Xor(const Vec128< T, N > a, const Vec128< T, N > b)
Definition: arm_neon-inl.h:1430
HWY_API V Clamp(const V v, const V lo, const V hi)
Definition: generic_ops-inl.h:41
decltype(detail::DeduceD()(V())) DFromV
Definition: arm_neon-inl.h:532
HWY_INLINE constexpr HWY_MAYBE_UNUSED size_t MaxLanes(Simd< T, N >)
Definition: shared-inl.h:194
typename D::template Repartition< T > Repartition
Definition: shared-inl.h:155
decltype(MaskFromVec(Zero(D()))) Mask
Definition: generic_ops-inl.h:37
HWY_API Vec128< TI > TableLookupBytes(const Vec128< T > bytes, const Vec128< TI > from)
Definition: arm_neon-inl.h:3957
HWY_API Vec128< T, N > Zero(Simd< T, N > d)
Definition: arm_neon-inl.h:710
HWY_API V CombineShiftRightLanes(const D d, const V hi, V lo)
Definition: rvv-inl.h:1562
typename D::T TFromD
Definition: shared-inl.h:140
HWY_API VI TableLookupBytesOr0(const V bytes, const VI from)
Definition: arm_neon-inl.h:4012
HWY_API Vec128< T, N > Or(const Vec128< T, N > a, const Vec128< T, N > b)
Definition: arm_neon-inl.h:1419
decltype(Zero(D())) Vec
Definition: generic_ops-inl.h:31
Definition: aligned_allocator.h:23
constexpr HWY_API bool IsSame()
Definition: base.h:260
typename detail::Relations< T >::Unsigned MakeUnsigned
Definition: base.h:521
constexpr T LimitsMax()
Definition: base.h:329
#define HWY_ALIGN
Definition: set_macros-inl.h:78
#define HWY_NAMESPACE
Definition: set_macros-inl.h:77
Definition: shared-inl.h:35