MatchLib
All Classes Namespaces Files Functions Modules Pages
comptrees.h
1/*
2 * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef _COMP_TREES_H_
17#define _COMP_TREES_H_
18
19/* Minmax and search tree implementations using template metaprogramming.
20 *
21 * These implementations are written with efficient HLS in mind, not for ease of
22 * programming.
23 */
24
25#include <iostream>
26#include <nvhls_marshaller.h>
27
28using namespace std;
29
71template <typename ArrT, typename ElemT, typename IdxT, bool is_max,
72 unsigned Width>
73class Minmax {
74 public:
75 static IdxT minmax(ArrT inputs, IdxT start, IdxT end) {
76 const unsigned TotalElems = Wrapped<ArrT>::width / Wrapped<ElemT>::width;
77 const unsigned ElemWidth = Wrapped<ElemT>::width;
78 ElemT temp[TotalElems];
79#pragma hls_unroll yes
80 for (unsigned i = 0; i < TotalElems; i++)
81 temp[i] = inputs.range((i + 1) * ElemWidth - 1, i * ElemWidth);
82
83 IdxT upper_branch = Minmax<ArrT, ElemT, IdxT, is_max, Width / 2>::minmax(
84 inputs, (start + end) / 2 + 1, end);
85 IdxT lower_branch = Minmax<ArrT, ElemT, IdxT, is_max, Width / 2>::minmax(
86 inputs, start, (start + end) / 2);
87 if (is_max)
88 return temp[upper_branch] > temp[lower_branch] ? upper_branch
89 : lower_branch;
90 else
91 return temp[upper_branch] < temp[lower_branch] ? upper_branch
92 : lower_branch;
93 }
94};
95
96/* Base condition for MPL minmax. */
97template <typename ArrT, typename ElemT, typename IdxT, bool is_max>
98class Minmax<ArrT, ElemT, IdxT, is_max, 2> {
99 public:
100 static IdxT minmax(ArrT inputs, IdxT start, IdxT end) {
101 const unsigned TotalElems = Wrapped<ArrT>::width / Wrapped<ElemT>::width;
102 const unsigned ElemWidth = Wrapped<ElemT>::width;
103 ElemT temp[TotalElems];
104#pragma hls_unroll yes
105 for (unsigned i = 0; i < TotalElems; i++)
106 temp[i] = inputs.range((i + 1) * ElemWidth - 1, i * ElemWidth);
107
108 if (is_max)
109 return temp[start] > temp[end] ? start : end;
110 else
111 return temp[start] < temp[end] ? start : end;
112 }
113};
114
115// Base condition for Width = 1.
116template <typename ArrT, typename ElemT, typename IdxT, bool is_max>
117class Minmax<ArrT, ElemT, IdxT, is_max, 1> {
118 public:
119 // Comparing an element to itself must always just return itself.
120 static IdxT minmax(ArrT inputs, IdxT start, IdxT end) { return start; }
121};
122
/* Compile time SC datatype concatenator tree.
 *
 * This class can be used to concatenate NumElements SC datatype objects,
 * stored in an array, into a single wide value. It exists to avoid using a for
 * loop on the concat() method.
 *
 * The general case peels off components[start], recursively concatenates the
 * remaining NumElements - 1 entries (components[start + 1] onwards) and joins
 * the two with the comma operator, with components[start] as the right-hand
 * operand.
 *
 * Template args:
 *   BaseTemplate: the underlying class storing each element (e.g. sc_int,
 *     sc_uint).
 *   ElemW: The width of each element to be concatenated.
 *   NumElements: The number of elements to be concatenated.
 *
 * Function args:
 *   components: array holding the elements.
 *   start: index of the first element to concatenate.
 *   end: index of the last element; it is forwarded unchanged and only read by
 *     the two-element base case, so it is expected to equal
 *     start + NumElements - 1.
 */
template <template <int> class BaseTemplate, unsigned ElemW,
          unsigned NumElements>
class Concat {
 public:
  typedef BaseTemplate<ElemW * NumElements> concat_t;
  typedef BaseTemplate<ElemW * NumElements / 2> concat_half_t;

  static concat_t concat(BaseTemplate<ElemW> components[], unsigned start,
                         unsigned end) {
    // Concatenate everything above components[start] first.
    BaseTemplate<ElemW*(NumElements - 1)> retval =
        Concat<BaseTemplate, ElemW, NumElements - 1>::concat(components,
                                                             start + 1, end);
    return (retval, components[start]);
  }
};
151
152// Base case for concatenating 2 elements
153template <template <int> class BaseTemplate, unsigned ElemW>
154class Concat<BaseTemplate, ElemW, 2> {
155 public:
156 typedef BaseTemplate<ElemW * 2> concat_t;
157
158 static concat_t concat(BaseTemplate<ElemW> components[], unsigned start,
159 unsigned end) {
160 return (components[end], components[start]);
161 }
162};
163
164// Base case for "concatenating" 1 element
165template <template <int> class BaseTemplate, unsigned ElemW>
166class Concat<BaseTemplate, ElemW, 1> {
167 public:
168 typedef BaseTemplate<ElemW> concat_t;
169
170 static concat_t concat(BaseTemplate<ElemW> components[], unsigned start,
171 unsigned end) {
172 return components[start];
173 }
174};
175
/* Priority Encoder:
 *   Find the first (starting from the LSB, i.e. index 0) logic false or logic
 *   true in a bitvector.
 *   Returns the bit position of the first match if found.
 *   Returns -1 if the desired logic value is not found in the bitvector.
 *
 * Template parameters:
 *   VecT: Bitvector type; must be indexable with operator[].
 *   ValT: Value type of the logic value being searched for.
 *   IdxT: Type of the return variable. It must be able to hold -1 as well as
 *     every index in [0, Width - 1] (should be size log2(Width)+1).
 *   Width: The size of the range to search in the array. Used to identify the
 *     base condition for MPL recursion.
 */

// General (recursive) case, used for Width > 1.
template <typename VecT, typename ValT, typename IdxT, unsigned Width>
class PriEnc {
 public:
  // This is the pri_enc function to call: it searches indices [0, Width - 1].
  static IdxT val(VecT inputs, ValT comp_value) {
    // This call will expand into some compile-time TMP recursion:
    return PriEnc<VecT, ValT, IdxT, Width>::val(inputs, comp_value, 0,
                                                Width - 1);
  }

  // Recursion step: tests inputs[start], then hands the remaining Width - 1
  // positions to the next narrower PriEnc.
  static IdxT val(VecT inputs, ValT comp_value, IdxT start, IdxT end) {
    if (inputs[start] == comp_value) {
      return start;
    }
    if (start == end) {
      return -1;
    }
    // The remaining range is [start + 1, start + Width - 1]; its upper bound
    // is derived from Width rather than taken from end.
    return PriEnc<VecT, ValT, IdxT, Width - 1>::val(inputs, comp_value,
                                                    start + 1,
                                                    start + Width - 1);
  }
};

// Specialized cases for Width = 1: terminates the recursion.
template <typename VecT, typename ValT, typename IdxT>
class PriEnc<VecT, ValT, IdxT, 1> {
 public:
  // Single-bit vector: only index 0 can match.
  static IdxT val(VecT inputs, ValT comp_value) {
    if (inputs[0] == comp_value) {
      return 0;
    }
    return -1;
  }

  // Last step of the recursion: only inputs[end] is left to test.
  static IdxT val(VecT inputs, ValT comp_value, IdxT start, IdxT end) {
    if (inputs[end] == comp_value) {
      return end;
    }
    return -1;
  }
};
235
236#endif
Compile-time minmax tree.
Definition comptrees.h:73