WIP: Add Elementwise Functions support #1500
base: dev
preusser left a comment
Thanks, @SpiritSeeker!
Please review the comments.
    @property
    def cpp_op(self):
        odt_hls_name = self.out_dtype.get_hls_datatype_str()
        return "({0} > 0 ? (%s){0} : (%s)0)" % (odt_hls_name, odt_hls_name)
The reversed comparison {0} < 0 is easier for most datatypes.
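A minimal sketch of the reversed comparison, not the PR's actual code: `relu_cpp_op` is a hypothetical stand-alone helper, and the HLS type name is passed in as a plain string instead of coming from `self.out_dtype.get_hls_datatype_str()`. The `{0}` placeholder for the input operand is left in the template and filled in afterwards.

```python
def relu_cpp_op(odt_hls_name: str) -> str:
    # Hypothetical helper: builds the C++ expression template for ReLU.
    # The comparison is reversed ({0} < 0), so the zero branch comes first.
    # "{0}" is untouched by %-formatting and stays as the operand placeholder.
    return "({0} < 0 ? (%s)0 : (%s){0})" % (odt_hls_name, odt_hls_name)


# Fill in the type first, then the input expression.
template = relu_cpp_op("ap_int<8>")
expr = template.format("x")
```

Both variants compute the same value; only the order of the two branches of the ternary changes.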
    inp_bw = self.inp_dtype.bitwidth()
    # The output would be unsigned with same bit-width as input
    # if input was unsigned, else one bit less
    out_bw = inp_bw - 1 if self.inp_dtype.signed() else inp_bw
- Consider issuing a warning when constructing an ElementwiseReLU node with an unsigned input type.
- You can only safely strip a bit from the output datatype if the input datatype is narrow, i.e. within [-2^(n-1) + 1 : 2^(n-1) - 1].
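The first point could be sketched roughly as follows. This is only an illustration: `warn_if_unsigned_relu` is a hypothetical helper, the signedness is passed in as a plain boolean instead of being read from `self.inp_dtype.signed()`, and the bit-width question from the second point is deliberately left out.

```python
import warnings


def warn_if_unsigned_relu(inp_signed: bool) -> None:
    # Hypothetical check, intended to run while the node is constructed.
    if not inp_signed:
        # ReLU is the identity on non-negative values, so the node is a no-op.
        warnings.warn("ElementwiseReLU with an unsigned input type is a no-op")


# Record warnings to show which call triggers one.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_unsigned_relu(inp_signed=False)  # emits a warning
    warn_if_unsigned_relu(inp_signed=True)   # stays silent
```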
    odt_hls_name = self.out_dtype.get_hls_datatype_str()
    # Explicitly use the overloads, using hls::exp results in minor errors
    if self.out_dtype.get_canonical_name() == "FLOAT32":
        return "(hls::expf((%s){0}))" % (odt_hls_name)
return "hls::exp(%s({0}))" % (odt_hls_name) should be the only return statement. Rely on function overload selection by the argument type for specialization.
    odt_hls_name = self.out_dtype.get_hls_datatype_str()
    # Explicitly use the overloads, using hls::erf results in minor errors
    if self.out_dtype.get_canonical_name() == "FLOAT32":
        return "(hls::erff((%s){0}))" % (odt_hls_name)
return "hls::erf(%s({0}))" % (odt_hls_name) should be the only return statement. Rely on function overload selection by the argument type for specialization.
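The two suggestions for Exp and Erf share one shape, which might be sketched as below. `unary_cpp_op` is a hypothetical helper that is not part of the PR, and the type names are passed in as plain strings for illustration. Whether this removes the accuracy differences mentioned in the code comments is not shown here.

```python
def unary_cpp_op(hls_func: str, odt_hls_name: str) -> str:
    # Hypothetical helper: a single template per function, no dtype branch.
    # The functional cast fixes the argument type, leaving it to C++
    # overload resolution to pick the matching hls:: implementation.
    return "hls::%s(%s({0}))" % (hls_func, odt_hls_name)


# "{0}" is the operand placeholder, filled in after the type is known.
exp_expr = unary_cpp_op("exp", "float").format("x")
erf_expr = unary_cpp_op("erf", "half").format("x")
```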
    # Generates C++ code for declaring all streams involved in C++ simulation
    # for testing
    def strm_decl(self):
        # Always add the output stream to the declarations
Why not more concise?
    self.code_gen_dict["$STREAMDECLARATIONS$"] = [
        # Note: Assumes stream type aliases to be set in defines
        "OutStream out0_V;",
        "InpStream in0_V;"
    ]
    #pragma HLS BIND_STORAGE variable=out type=RAM_S2P impl=LUTRAM
    """,
    # Perfect loop nest over all folded output dimensions
    *[for_loop(dim, size) + " {" for dim, size in enumerate(out_shape)],
This seems to be equivalent to a single flat loop using the product over all dimensions as its bound.
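A rough sketch of that equivalence, assuming the loop body does not need the individual per-dimension indices. `flat_loop` and `unravel` are hypothetical helpers, and the emitted C++ loop header is only one possible spelling.

```python
import math
from itertools import product


def flat_loop(out_shape):
    # Hypothetical generator: one C++ loop header whose bound is the
    # product over all folded output dimensions.
    return "for(std::size_t i = 0; i < %d; ++i) {" % math.prod(out_shape)


def unravel(i, shape):
    # Recover the per-dimension indices from a flat row-major index.
    idx = []
    for size in reversed(shape):
        i, r = divmod(i, size)
        idx.append(r)
    return tuple(reversed(idx))


out_shape = (2, 3, 4)
header = flat_loop(out_shape)

# The perfect loop nest visits these index tuples in row-major order ...
nested = list(product(*(range(size) for size in out_shape)))
# ... and the flat loop visits the same tuples in the same order.
flat = [unravel(i, out_shape) for i in range(math.prod(out_shape))]
```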
    # Add HLS interface directives specifying how to create RTL ports for
    # the top-level function arguments
    self.code_gen_dict["$PRAGMAS$"] += [
Fuse into a single compact append for both lines of code.
    def get_verilog_top_module_intf_names(self):
        # Start collecting interface names in a dictionary starting with clock
        # and reset
        intf_names = {"clk": ["ap_clk"], "rst": ["ap_rst_n"]}
Pick up all the other associations as part of the initialization.
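One way this could look, as a sketch under assumptions: the rest of the method is not shown in the excerpt, so the `s_axis` and `m_axis` keys, the port names, and the width parameters below are hypothetical placeholders, and `get_intf_names` is a stand-alone function rather than the actual method.

```python
def get_intf_names(in_width: int, out_width: int) -> dict:
    # Hypothetical stand-alone version: every association is part of the
    # dictionary literal instead of being added key by key afterwards.
    return {
        "clk": ["ap_clk"],
        "rst": ["ap_rst_n"],
        # Assumed keys and port names, for illustration only.
        "s_axis": [("in0_V", in_width)],
        "m_axis": [("out0_V", out_width)],
    }


intf_names = get_intf_names(in_width=8, out_width=8)
```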
This PR adds support for elementwise functions (elementwise unary operations), allowing easy extension to all math functions of the HLS Math Library. It currently supports ReLU, elementwise Exp, and elementwise Erf.
Status of tests:
✔️ ReLU - FLOAT32, FLOAT16, INT, FIXED (cppsim and rtlsim)
✔️ Exp and Erf - cppsim + FLOAT32
✔️ Exp - rtlsim: FLOAT32 and FLOAT16
✖️ Exp and Erf - cppsim + FLOAT16 (numpy and simulated results differ significantly)
✖️ Erf - rtlsim: FLOAT32 and FLOAT16 (RTL watchdog timeout error)