MatchLib
comptrees.h
1 /*
2  * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License")
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef _COMP_TREES_H_
17 #define _COMP_TREES_H_
18 
/* Minmax tree, SC datatype concatenator, and priority encoder implemented
 * using template metaprogramming.
 *
 * These implementations are written with efficient HLS in mind, not for ease
 * of programming.
 */
24 
25 #include <iostream>
26 #include <nvhls_marshaller.h>
27 
28 using namespace std;
29 
71 template <typename ArrT, typename ElemT, typename IdxT, bool is_max,
72  unsigned Width>
73 class Minmax {
74  public:
75  static IdxT minmax(ArrT inputs, IdxT start, IdxT end) {
76  const unsigned TotalElems = Wrapped<ArrT>::width / Wrapped<ElemT>::width;
77  const unsigned ElemWidth = Wrapped<ElemT>::width;
78  ElemT temp[TotalElems];
79 #pragma hls_unroll yes
80  for (unsigned i = 0; i < TotalElems; i++)
81  temp[i] = inputs.range((i + 1) * ElemWidth - 1, i * ElemWidth);
82 
83  IdxT upper_branch = Minmax<ArrT, ElemT, IdxT, is_max, Width / 2>::minmax(
84  inputs, (start + end) / 2 + 1, end);
85  IdxT lower_branch = Minmax<ArrT, ElemT, IdxT, is_max, Width / 2>::minmax(
86  inputs, start, (start + end) / 2);
87  if (is_max)
88  return temp[upper_branch] > temp[lower_branch] ? upper_branch
89  : lower_branch;
90  else
91  return temp[upper_branch] < temp[lower_branch] ? upper_branch
92  : lower_branch;
93  }
94 };
95 
96 /* Base condition for MPL minmax. */
97 template <typename ArrT, typename ElemT, typename IdxT, bool is_max>
98 class Minmax<ArrT, ElemT, IdxT, is_max, 2> {
99  public:
100  static IdxT minmax(ArrT inputs, IdxT start, IdxT end) {
101  const unsigned TotalElems = Wrapped<ArrT>::width / Wrapped<ElemT>::width;
102  const unsigned ElemWidth = Wrapped<ElemT>::width;
103  ElemT temp[TotalElems];
104 #pragma hls_unroll yes
105  for (unsigned i = 0; i < TotalElems; i++)
106  temp[i] = inputs.range((i + 1) * ElemWidth - 1, i * ElemWidth);
107 
108  if (is_max)
109  return temp[start] > temp[end] ? start : end;
110  else
111  return temp[start] < temp[end] ? start : end;
112  }
113 };
114 
115 // Base condition for Width = 1.
116 template <typename ArrT, typename ElemT, typename IdxT, bool is_max>
117 class Minmax<ArrT, ElemT, IdxT, is_max, 1> {
118  public:
119  // Comparing an element to itself must always just return itself.
120  static IdxT minmax(ArrT inputs, IdxT start, IdxT end) { return start; }
121 };
122 
/* Compile-time SC datatype concatenator.
 *
 * Concatenates NumElements SC datatype objects stored in an array. The
 * recursion over NumElements is expanded at compile time, which avoids
 * calling the concat() method inside a run-time for loop.
 *
 * Template args:
 *   BaseTemplate: the underlying class storing each element (e.g. sc_int,
 *                 sc_uint).
 *   ElemW:        the width of each element to be concatenated.
 *   NumElements:  the number of elements to be concatenated.
 *
 * concat(components, start, end) produces
 * (components[end], ..., components[start + 1], components[start]).
 * Assuming the SystemC convention that the left operand of `,` forms the
 * most-significant bits, components[start] ends up in the low ElemW bits.
 * The recursion visits start, start + 1, ... and finally `end`, so callers
 * presumably pass end == start + NumElements - 1.
 */
template <template <int> class BaseTemplate, unsigned ElemW,
          unsigned NumElements>
class Concat {
 public:
  typedef BaseTemplate<ElemW * NumElements> concat_t;
  typedef BaseTemplate<ElemW * NumElements / 2> concat_half_t;

  static concat_t concat(BaseTemplate<ElemW> components[], unsigned start,
                         unsigned end) {
    // Concatenate the remaining NumElements - 1 elements first, then place
    // components[start] to their right.
    BaseTemplate<ElemW*(NumElements - 1)> retval =
        Concat<BaseTemplate, ElemW, NumElements - 1>::concat(components,
                                                             start + 1, end);
    return (retval, components[start]);
  }
};
151 
152 // Base case for concatenating 2 elements
153 template <template <int> class BaseTemplate, unsigned ElemW>
154 class Concat<BaseTemplate, ElemW, 2> {
155  public:
156  typedef BaseTemplate<ElemW * 2> concat_t;
157 
158  static concat_t concat(BaseTemplate<ElemW> components[], unsigned start,
159  unsigned end) {
160  return (components[end], components[start]);
161  }
162 };
163 
164 // Base case for "concatenating" 1 element
165 template <template <int> class BaseTemplate, unsigned ElemW>
166 class Concat<BaseTemplate, ElemW, 1> {
167  public:
168  typedef BaseTemplate<ElemW> concat_t;
169 
170  static concat_t concat(BaseTemplate<ElemW> components[], unsigned start,
171  unsigned end) {
172  return components[start];
173  }
174 };
175 
/* Priority Encoder:
 * Finds the first position (starting from index 0, the LSB) in a bitvector
 * whose bit equals a given value (logic false or logic true).
 * Returns the bit position of the first match if found.
 * Returns -1 (converted to IdxT) if the desired value is not found.
 *
 * Template parameters:
 * VecT:  Bitvector type (indexed with operator[]).
 * ValT:  Value type.
 * IdxT:  type of the return variable (should be size log2(Width)+1).
 * Width: The size of the range to search in the bitvector. Used to identify
 *        the base condition for MPL recursion.
 */

// Recursive case. The recursion ends at the Width = 1 specialization.
template <typename VecT, typename ValT, typename IdxT, unsigned Width>
class PriEnc {
 public:
  // Entry point: searches bits [0, Width - 1]. The call expands into
  // compile-time TMP recursion.
  static IdxT val(VecT inputs, ValT comp_value) {
    IdxT retval =
        PriEnc<VecT, ValT, IdxT, Width>::val(inputs, comp_value, 0, Width - 1);
    return retval;
  }

  // Searches bits [start, end], which is expected to span Width bits.
  static IdxT val(VecT inputs, ValT comp_value, IdxT start, IdxT end) {
    if (inputs[start] == comp_value) {
      return start;
    } else if (start == end) {
      return -1;
    } else {
      // Continue with the remaining Width - 1 bits,
      // [start + 1, start + Width - 1].
      IdxT retval = PriEnc<VecT, ValT, IdxT, Width - 1>::val(
          inputs, comp_value, start + 1, start + Width - 1);
      return retval;
    }
  }
};
214 
215 // Specialized cases for Width = 1
216 template <typename VecT, typename ValT, typename IdxT>
217 class PriEnc<VecT, ValT, IdxT, 1> {
218  public:
219  static IdxT val(VecT inputs, ValT comp_value) {
220  if (inputs[0] == comp_value) {
221  return 0;
222  } else {
223  return -1;
224  }
225  }
226 
227  static IdxT val(VecT inputs, ValT comp_value, IdxT start, IdxT end) {
228  if (inputs[end] == comp_value) {
229  return end;
230  } else {
231  return -1;
232  }
233  }
234 };
235 
236 #endif
STL namespace.
Compile-time minmax tree.
Definition: comptrees.h:73