DOLFINx
DOLFINx C++ interface
sort.h
1// Copyright (C) 2021 Igor Baratta
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <dolfinx/common/Timer.h>
#include <numeric>
#include <type_traits>
#include <vector>
#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtl/xspan.hpp>
19
20namespace dolfinx
21{
22
28template <typename T, int BITS = 8>
29void radix_sort(const xtl::span<T>& array)
30{
31 static_assert(std::is_integral<T>(), "This function only sorts integers.");
32
33 if (array.size() <= 1)
34 return;
35
36 T max_value = *std::max_element(array.begin(), array.end());
37
38 // Sort N bits at a time
39 constexpr int bucket_size = 1 << BITS;
40 T mask = (T(1) << BITS) - 1;
41
42 // Compute number of iterations, most significant digit (N bits) of
43 // maxvalue
44 int its = 0;
45 while (max_value)
46 {
47 max_value >>= BITS;
48 its++;
49 }
50
51 // Adjacency list arrays for computing insertion position
52 std::array<std::int32_t, bucket_size> counter;
53 std::array<std::int32_t, bucket_size + 1> offset;
54
55 std::int32_t mask_offset = 0;
56 std::vector<T> buffer(array.size());
57 xtl::span<T> current_perm = array;
58 xtl::span<T> next_perm = buffer;
59 for (int i = 0; i < its; i++)
60 {
61 // Zero counter array
62 std::fill(counter.begin(), counter.end(), 0);
63
64 // Count number of elements per bucket
65 for (T c : current_perm)
66 counter[(c & mask) >> mask_offset]++;
67
68 // Prefix sum to get the inserting position
69 offset[0] = 0;
70 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
71 for (T c : current_perm)
72 {
73 std::int32_t bucket = (c & mask) >> mask_offset;
74 std::int32_t new_pos = offset[bucket + 1] - counter[bucket];
75 next_perm[new_pos] = c;
76 counter[bucket]--;
77 }
78
79 mask = mask << BITS;
80 mask_offset += BITS;
81
82 std::swap(current_perm, next_perm);
83 }
84
85 // Copy data back to array
86 if (its % 2 != 0)
87 std::copy(buffer.begin(), buffer.end(), array.begin());
88}
89
98template <typename T, int BITS = 16>
99void argsort_radix(const xtl::span<const T>& array,
100 xtl::span<std::int32_t> perm)
101{
102 static_assert(std::is_integral<T>::value, "Integral required.");
103
104 if (array.size() <= 1)
105 return;
106
107 const auto [min, max] = std::minmax_element(array.begin(), array.end());
108 T range = *max - *min + 1;
109
110 // Sort N bits at a time
111 constexpr int bucket_size = 1 << BITS;
112 T mask = (T(1) << BITS) - 1;
113 std::int32_t mask_offset = 0;
114
115 // Compute number of iterations, most significant digit (N bits) of
116 // maxvalue
117 int its = 0;
118 while (range)
119 {
120 range >>= BITS;
121 its++;
122 }
123
124 // Adjacency list arrays for computing insertion position
125 std::array<std::int32_t, bucket_size> counter;
126 std::array<std::int32_t, bucket_size + 1> offset;
127
128 std::vector<std::int32_t> perm2(perm.size());
129 xtl::span<std::int32_t> current_perm = perm;
130 xtl::span<std::int32_t> next_perm = perm2;
131 for (int i = 0; i < its; i++)
132 {
133 // Zero counter
134 std::fill(counter.begin(), counter.end(), 0);
135
136 // Count number of elements per bucket
137 for (auto cp : current_perm)
138 {
139 T value = array[cp] - *min;
140 std::int32_t bucket = (value & mask) >> mask_offset;
141 counter[bucket]++;
142 }
143
144 // Prefix sum to get the inserting position
145 offset[0] = 0;
146 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
147
148 // Sort py permutation
149 for (auto cp : current_perm)
150 {
151 T value = array[cp] - *min;
152 std::int32_t bucket = (value & mask) >> mask_offset;
153 std::int32_t pos = offset[bucket + 1] - counter[bucket];
154 next_perm[pos] = cp;
155 counter[bucket]--;
156 }
157
158 std::swap(current_perm, next_perm);
159
160 mask = mask << BITS;
161 mask_offset += BITS;
162 }
163
164 if (its % 2 == 1)
165 std::copy(perm2.begin(), perm2.end(), perm.begin());
166}
167
177template <typename T, int BITS = 16>
178std::vector<std::int32_t> sort_by_perm(const xtl::span<const T>& x,
179 std::size_t shape1)
180{
181 static_assert(std::is_integral<T>::value, "Integral required.");
182 assert(shape1 > 0);
183 assert(x.size() % shape1 == 0);
184 const std::size_t shape0 = x.size() / shape1;
185 std::vector<std::int32_t> perm(shape0);
186 std::iota(perm.begin(), perm.end(), 0);
187
188 // Sort by each column, right to left. Col 0 has the most signficant
189 // "digit".
190 std::vector<T> column(shape0);
191 for (std::size_t i = 0; i < shape1; ++i)
192 {
193 int col = shape1 - 1 - i;
194 for (std::size_t j = 0; j < shape0; ++j)
195 column[j] = x[j * shape1 + col];
196 argsort_radix<T, BITS>(column, perm);
197 }
198
199 return perm;
200}
201
202} // namespace dolfinx