NVBIO
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
wavelet_tree_inl.h
Go to the documentation of this file.
1 /*
2  * nvbio
3  * Copyright (c) 2011-2014, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of the NVIDIA CORPORATION nor the
13  * names of its contributors may be used to endorse or promote products
14  * derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 
28 #pragma once
29 
30 namespace nvbio {
31 
32 template <typename symbol_type> // unary functor extracting a single bit from each input symbol
34 { // NOTE(review): the line declaring this struct (listing line 33) is missing from this listing
35  typedef symbol_type argument_type;
36  typedef symbol_type result_type;
37 
39  select_bit_functor(const uint32 _i) : i(_i) {} // _i: index of the bit to select (0 = least significant)
40 
42  symbol_type operator() (const symbol_type c) const { return (c >> i) & 1u; } // returns bit i of c, i.e. 0 or 1
43 
44  const uint32 i; // the bit index, fixed at construction
45 };
46 
47 // a private functor returning the number of 1's found strictly before a global bit index r (an exclusive rank)
48 // the result is the occurrence counter of the 32-bit block holding bit r-1, plus a popcount within that block's word
49 template <typename WaveletTreeType, typename OccIterator>
51 {
52  typedef typename WaveletTreeType::index_type index_type;
53 
56 
57  // constructor: _tree is the wavelet tree whose bit-stream words are read,
58  // _occ is the iterator over the per-word occurrence counters; both are stored by value
61  const WaveletTreeType _tree,
62  const OccIterator _occ) :
63  tree( _tree ), occ( _occ ) {}
64 
65  // unary transform operator: maps a global bit index r to the number of 1's preceding it
66  // NOTE(review): the operator's signature line is missing from this listing, so the type of r cannot be confirmed here
69  {
70  if (r == 0)
71  return 0;
72  else
73  {
74  const uint32 i = r-1; // return the number of ones at i = r-1 -- NOTE(review): i is 32-bit; if r is a wider index_type this truncates, confirm
75  const index_type base_index = i / 32u; // the index of this occurrence counter block
76  const index_type base_occ = occ[ base_index ]; // the base occurrence counter for this block
77 
78  const uint32 word = tree.bits().stream()[ base_index ]; // the block's corresponding word
79  const uint32 word_occ = popc_nbit<1u>( word, 1u, ~i & 31u ); // the delta popcount in this word
80 
81  return base_occ + word_occ; // the final result
82  }
83  }
84 
85  const WaveletTreeType tree; // copy of the tree passed to the constructor
86  const OccIterator occ; // copy of the occurrence counter iterator passed to the constructor
87 };
88 
89 namespace priv {
90 namespace wtree {
91 
92 typedef std::pair<uint64,uint64> index_range; // a (first,second) pair of string positions; setup() uses it as a half-open range
93 typedef std::pair<uint64,uint64> symbol_range; // a (first,second) pair of symbol values; setup() uses it as a half-open range
94 
95 struct node_type // a tree node as kept on the traversal stack built by setup()
96 {
97  node_type() {} // default constructor: performs no explicit member initialization
99  const uint32 _level, // the level assigned to the node (setup() passes 0 for the root and level+1 for children)
100  const uint32 _index, // the node's index (setup() gives the children of node n the indices 2n+1 and 2n+2)
101  const symbol_range _s_range, // the range of symbols covered by the node
102  const index_range _i_range) : // the range of string positions covered by the node
103  level( _level ),
104  index( _index ),
105  s_range( _s_range ),
106  i_range( _i_range ) {}
107 
112 }; // NOTE(review): the member declarations (listing lines 108-111) are missing from this listing
113 
114 } // namespace wtree
115 } // namespace priv
116 
117 //
118 // build a Wavelet Tree out of a string: the output consists of a bit-string representing
119 // the different bit-planes of the output WaveleTree, and an array of offsets to the
120 // leaves, containing the sorted string symbols.
121 //
122 // \tparam string_iterator the string type: must provide a random access iterator
123 // interface as well as define a proper stream_traits<string_iterator>
124 // expansion; particularly, stream_traits<string_iterator>::SYMBOL_SIZE
125 // is used to infer the number of bits needed to represent the
126 // symbols in the string's alphabet
127 //
128 template <typename system_tag, typename string_iterator, typename index_type, typename symbol_type>
129 void setup(
130  const index_type string_len,
131  const string_iterator& string,
133 {
134  typedef typename WaveletTreeStorage<system_tag,index_type>::plain_view_type wavelet_tree_view_type;
135  typedef typename WaveletTreeStorage<system_tag,index_type>::bit_iterator bit_iterator;
136  typedef typename WaveletTreeStorage<system_tag,index_type>::index_iterator occ_iterator;
137 
139 
140  // resize the output tree
141  out_tree.resize( string_len, symbol_size );
142 
143  // allocate a temporary string used for sorting
144  nvbio::vector<system_tag,symbol_type> sorted_string( string_len * 2 );
145 
147 
148  if (equal<system_tag,device_tag>())
149  {
150  // copy the input to the temporary sorting string
151  thrust::copy( string, string + string_len, sorted_string.begin() );
152 
153  cuda::SortBuffers<symbol_type*> sort_buffers;
154  sort_buffers.keys[0] = raw_pointer( sorted_string );
155  sort_buffers.keys[1] = raw_pointer( sorted_string ) + string_len;
156 
157  cuda::SortEnactor sort_enactor;
158 
159  // loop through all bit planes in the range [0,symbol_size)
160  for (uint32 i = 0; i < symbol_size; ++i)
161  {
162  const uint32 bit = symbol_size - i - 1u;
163 
164  // copy the i-th bit-plane to the output
165  priv::device_assign(
166  string_len,
168  thrust::device_ptr<symbol_type>( sort_buffers.current_keys() ),
170  out_tree.bits() + string_len * i );
171 
172  // sort by the leading i+1 bits
173  sort_enactor.sort( string_len, sort_buffers, bit, symbol_size );
174  }
175 
176  // setup the pointer to the fully sorted string
177  sorted_keys = sort_buffers.selector ?
178  sorted_string.begin() + string_len :
179  sorted_string.begin();
180  }
181  else
182  {
183  // copy the input to the temporary sorting string
184  thrust::copy( string, string + string_len, sorted_string.begin() );
185 
186  // loop through all bit planes in the range [0,symbol_size)
187  for (uint32 i = 0; i < symbol_size; ++i)
188  {
189  const uint32 bit = symbol_size - i - 1u;
190 
191  // copy the i-th bit-plane to the output
193  string_len,
195  sorted_string.begin(),
197  out_tree.bits() + string_len * i );
198 
199  // extract the leading bits
200  nvbio::transform<system_tag>(
201  string_len,
202  sorted_string.begin(),
203  sorted_string.begin() + string_len,
204  leading_bits<symbol_type>( i + 1u ) );
205 
206  // sort by the leading i+1 bits
207  thrust::sort_by_key( sorted_string.begin() + string_len, sorted_string.begin() + string_len, sorted_string.begin() );
208  }
209 
210  // setup the pointer to the fully sorted string
211  sorted_keys = sorted_string.begin();
212  }
213 
214  //
215  // now compute the Wavelet Tree's leaf offsets, representing integer
216  // offsets to the symbol runs in the lexicographically sorted string:
217  // in practice, this is equivalent to the prefix-sum of the symbol frequencies;
218  // for example, if the input string is '0301133' over a 2-bit alphabet, the
219  // sorted string will be '0011333', and the output offsets will be [0,2,4,4,7],
220  // as there are two 0's, two 1's, zero 2's, and three 3's.
221  //
222 
223  const uint32 n_symbols = 1u << symbol_size;
224 
225  nvbio::vector<system_tag,index_type> leaves( n_symbols + 1u );
226 
227  // the offset to the first node is always zero
228  leaves[0] = 0;
229 
230  // now find the offset to each symbol greater than zero
231  nvbio::upper_bound<system_tag>(
232  n_symbols,
233  thrust::make_counting_iterator<index_type>(0u),
234  string_len,
235  sorted_keys,
236  leaves.begin() + 1u );
237 
238  // build the binary heap tree of split points
239  //
240  // e.g. in the example above, this will be: [4, 2, 4]
241  // corresponding to the splits at symbols: 2, 1, 3
242  // i.e:
243  // nodes[0] = leaves[ 2 = split[0,4] ]; // root
244  // nodes[1] = leaves[ 1 = split[0,2] ]; // 1st child
245  // nodes[2] = leaves[ 3 = split[2,4] ]; // 2nd child
246  //
247 
250  typedef priv::wtree::node_type node_type;
251 
252  std::stack<node_type> stack;
253  stack.push( node_type( 0u, 0u, std::make_pair( 0u, n_symbols ), std::make_pair( 0u, string_len ) ) );
254 
255  nvbio::vector<host_tag,index_type> h_splits( n_symbols - 1u );
256  nvbio::vector<host_tag,index_type> h_lookups( n_symbols - 1u );
257  nvbio::vector<host_tag,index_type> h_leaves( leaves );
258 
259  // visit all the nodes in the tree
260  while (!stack.empty())
261  {
262  // fetch the top of the stack
263  node_type node = stack.top();
264  stack.pop();
265 
266  index_range i_range = node.i_range;
267  symbol_range s_range = node.s_range;
268 
269  // check whether this is a leaf node
270  if (s_range.second - s_range.first == 1)
271  continue;
272 
273  // calculate the left and right child indices
274  const uint32 l_node_index = node.index*2u + 1u;
275  const uint32 r_node_index = node.index*2u + 2u;
276  const uint32 s_split = (s_range.second + s_range.first)/2u;
277  const uint32 i_split = h_leaves[ s_split ]; // # symbols preceeding 's_split
278 
279  // write out the split point for this node
280  h_splits[ node.index ] = index_type( i_split );
281 
282  // store the this node's beginning index in the global bit vector
283  h_lookups[ node.index ] = index_type( i_range.first + node.level * string_len );
284 
285  // push the children onto the stack
286  stack.push( node_type( node.level + 1u, r_node_index, std::make_pair( s_split, s_range.second ), std::make_pair( i_split, i_range.second ) ) );
287  stack.push( node_type( node.level + 1u, l_node_index, std::make_pair( s_range.first, s_split ), std::make_pair( i_range.first, i_split ) ) );
288  }
289 
290  // and copy it to the output
291  thrust::copy( h_splits.begin(), h_splits.begin() + n_symbols - 1u, out_tree.splits() );
292 
293  // build the rank dictionary structure
294  {
295  // copy the nodes to the device
296  nvbio::vector<system_tag,index_type> lookups( h_lookups );
297  nvbio::vector<system_tag,uint8> temp_storage;
298 
299  typedef typename bit_iterator::storage_iterator words_iterator;
300 
301  const words_iterator words = out_tree.bits().stream();
302  const uint32 n_words = util::divide_ri( string_len * symbol_size, 32u );
303 
304  // compute the exclusive sum of the popcount of all the words in the bit-stream
306  n_words,
308  words,
310  out_tree.occ() + n_symbols,
311  thrust::plus<index_type>(),
312  0u,
313  temp_storage );
314 
316  plain_view( out_tree ),
317  out_tree.occ() + n_symbols );
318 
319  // and now build the tree structure by doing simple lookups
320  nvbio::transform<system_tag>(
321  n_symbols - 1u, // n
322  lookups.begin(), // input
323  out_tree.occ(), // output
324  ranker ); // functor
325 
326  out_tree.occ()[ n_symbols ] = 0u;
327  }
328 }
329 
330 // return the i-th symbol
331 // (a thin wrapper forwarding to the free function text(); its signature lines are missing from this listing)
332 template <typename BitStreamIterator, typename IndexIterator, typename SymbolType>
335 {
336  return text( *this, i ); // decode the symbol by passing this tree and i to text()
337 }
338 
339 // return the number of bits set to b in the range [0,r] within node n at level l
340 // here node indexes the node's counter in m_occ, node_begin is the node's offset within level l, and r is relative to node_begin (inclusive)
341 template <typename BitStreamIterator, typename IndexIterator, typename SymbolType>
344 WaveletTree<BitStreamIterator,IndexIterator,SymbolType>::rank(const uint32 l, const uint32 node, const index_type node_begin, const index_type r, const uint8 b) const
345 {
346  const uint32 n_nodes = 1u << symbol_size();
347 
348  // the global index of the beginning of the node is given by its local index into its level,
349  // plus the global offset of the level into the bit string, which contains size() symbols
350  // per level; it is kept as an index_type, as a 32-bit value would be truncated with wider index types
351  const index_type global_node_begin = node_begin + l * size();
352 
353  const index_type ones = m_occ[ node ]; // # of occurrences of 1's preceding the node's beginning
354  const index_type zeros = global_node_begin - ones; // # of occurrences of 0's preceding the node's beginning
355  const index_type offset = b ? ones : zeros; // number of occurrences of b at the node's beginning
356 
357 
358  const index_type global_index = nvbio::min(
359  index_type( global_node_begin + r ), // the global position of r in the bit-string
360  index_type( size() * (l+1u) - 1u ) ); // maximum index for this level
361  const index_type global_index_mod = ~global_index & 31u;
362 
363  const index_type base_index = global_index / 32u; // the index of this occurrence counter block
364  const index_type base_occ = b ? m_occ[ base_index + n_nodes ] :
365  base_index*32u - m_occ[ base_index + n_nodes ]; // the 0's preceding the block are its bit offset minus the 1's
366 
367  const uint32 word = bits().stream()[ base_index ]; // the block's corresponding word
368  const uint32 word_occ = popc_nbit<1u>( word, b, global_index_mod );// the inclusive popcount
369 
370  return base_occ + word_occ - offset; // subtract the occurrences preceding the node to make the count node-relative
371 }
372 
373 // \relates WaveletTree
374 // fetch the number of occurrences of character c in the substring [0,i]
375 //
376 // \param tree the wavelet tree (rank dictionary) to query
377 // \param i the end of the query range [0,i]
378 // \param c the query character
379 // \return the occurrence count, or zero as soon as the descent reaches an empty node
380 template <typename BitStreamIterator, typename IndexIterator, typename SymbolType>
386  const uint32 c)
387 {
389 
390  const uint32 symbol_size = tree.symbol_size();
391 
392  // traverse the tree from the root node down to the leaf containing c
393  uint32 node = 0u;
394  index_type range_lo = 0u;
395  index_type range_hi = tree.size();
396 
397  index_type r = i+1; // the number of leading positions still under consideration (i is an inclusive end-point)
398 
399  for (uint32 l = 0; l < symbol_size; ++l)
400  {
401  // we got to an empty node, the rank must be zero
402  if (range_lo == range_hi)
403  return 0u;
404 
405  // select the l-th level bit of c, starting from the most significant one
406  const uint32 b = (c >> (symbol_size - l - 1u)) & 1u;
407 
408  // r is the new relative rank of c within the child node
409  r = r ? tree.rank( l, node, range_lo, r-1, b ) : 0u;
410 
411  // compute the base (i.e. left) child node
412  const uint32 child = node*2u + 1u;
413 
414  const uint32 split = tree.splits()[ node ]; // NOTE(review): 32-bit, while range_lo/range_hi are index_type -- may truncate with wider index types, confirm
415 
416  if (b == 1)
417  {
418  // descend into the right node
419  range_lo = split;
420  node = child + 1u;
421  }
422  else
423  {
424  // descend into the left node
425  range_hi = split;
426  node = child;
427  }
428  }
429  return r;
430 }
431 
432 // \relates WaveletTree
433 // fetch the number of occurrences of character c in the substrings [0,range.x] and [0,range.y]
434 //
435 // \param tree the wavelet tree (rank dictionary) to query
436 // \param range the pair of query end-points: range.x and range.y are each ranked independently
437 // \param c the query character
438 // \return the pair of ranks, in the same order as the end-points
439 template <typename BitStreamIterator, typename IndexIterator, typename SymbolType>
441 typename WaveletTree<BitStreamIterator,IndexIterator,SymbolType>::range_type
445  const uint32 c)
446 {
447  return make_vector(
448  rank( tree, range.x, c ),
449  rank( tree, range.y, c ) );
450 }
451 
452 // \relates WaveletTree
453 // fetch the text character at position i in the rank dictionary
454 // \param tree the wavelet tree to query; \param i the position of the character; the symbol is rebuilt one bit per level, most significant bit first
455 template <typename BitStreamIterator, typename IndexIterator, typename IndexType, typename SymbolType>
457 SymbolType text(const WaveletTree<BitStreamIterator,IndexIterator,SymbolType>& tree, const IndexType i)
458 {
459  const uint32 symbol_size = tree.symbol_size();
460  const typename WaveletTree<BitStreamIterator,IndexIterator,SymbolType>::index_type string_len = tree.size(); // kept at the tree's index width: a 32-bit value would make string_len*l wrap around on long strings
461 
462  // traverse the tree from the root node down to the leaf containing c
463  uint32 node = 0u;
464  IndexType range_lo = 0u;
465 
466  IndexType r = i;
467 
468  SymbolType c = 0;
469 
470  for (uint32 l = 0; l < symbol_size; ++l)
471  {
472  // read the character in position r at level l: range_lo is the node's offset within the level, string_len*l the level's offset in the bit-string
473  const uint32 b = tree.bits()[ r + range_lo + string_len*l ];
474 
475  // insert b at the proper level in c
476  c |= b << (symbol_size - l - 1u);
477 
478  // r is the new relative rank of c within the child node
479  r = r ? tree.rank( l, node, range_lo, r-1, b ) : 0u;
480 
481  // compute the base (i.e. left) child node
482  const uint32 child = node*2u + 1u;
483 
484  if (b == 1)
485  {
486  // descend into the right node, which begins at this node's split point
487  range_lo = tree.splits()[ node ];
488  node = child + 1u;
489  }
490  else
491  {
492  // descend into the left node, which begins where this node does
493  node = child;
494  }
495  }
496  return c;
497 }
498 
499 } // namespace nvbio