import warnings
from .basic_utils import basicOperator, deep_copy_dicts, overwrite_dicts
__all__ = ["bias_update"]

default_params = {
# op related
"op": "BIAS_UPDATE",
"input_idx": None,
"output_idx": None,
# tensor related
"input_size": None,
"input_buf_add": None,
"input_buf_add_offset": None,
"QAS": None,
"input_dtype": "float32",
"output_dtype": "float32",
}
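
# What this op generates: bias_update emits C for an in-place SGD-style bias
# update,
#     bias[i] -= grad[i] * blr * (1.0 / QAS[i]),
# where `blr` (presumably the bias learning rate) and the bias/gradient
# buffers are assumed to be declared by the surrounding generated code, and
# QAS appears to hold per-channel scale values supplied at construction time.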
class bias_update(basicOperator):
def __init__(self, params: dict) -> None:
self.params = deep_copy_dicts(default_params)
overwrite_dicts(self.params, params)
super().__init__()
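        # Presumably registers the gradient tensor as this op's input; the
        # trailing (1, 1) arguments follow basicOperator's convention and are
        # kept as-is here.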
self._add_input(
self.params["input_idx"],
self.params["input_dtype"],
self.params["input_size"],
1,
1,
)
        if None in self.params.values():
            warnings.warn(f"parameters are not all set for op {self.params['op']}")

    def generate_inference_str(self):
params = self.params
        # Emit a constant table holding the element-wise reciprocal of QAS
        # (QAS is expected to support vectorized division, e.g., a numpy array).
        string = f"const float {params['output_idx']}_QAS[{params['input_size']}] = " + "{"
        for i in 1 / params["QAS"]:
            string += str(i) + ","
        string += "};\n"
        # Resolve the address string of the gradient input buffer.
        input_str = self._getBufferstr(params["input_buf_add"], params["input_buf_add_offset"])
        # Both branches emit the same update loop; `fptr`/`int32ptr` and `blr`
        # are assumed to be declared by the surrounding generated code.
        if params["input_dtype"] == "float32":
            string += (
                f"fptr = (float*){input_str};//gradients\n"
                + f"for(int i = 0; i < {params['input_size']}; i++) {params['output_idx']}[i] -="
                + f" (int32_t)(fptr[i] * blr * {params['output_idx']}_QAS[i]);\n"
            )
        elif params["input_dtype"] == "int32":
            string += (
                f"int32ptr = (int32_t*){input_str};//gradients\n"
                + f"for(int i = 0; i < {params['input_size']}; i++) {params['output_idx']}[i] -="
                + f" (int32_t)(int32ptr[i] * blr * {params['output_idx']}_QAS[i]);\n"
            )
        else:
            raise NotImplementedError(f"unsupported input_dtype: {params['input_dtype']}")
return string
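
# A minimal usage sketch (illustrative only; every parameter value below is
# an assumption, and the exact buffer string depends on basicOperator's
# _getBufferstr implementation):
#
#     import numpy as np
#     op = bias_update({
#         "input_idx": "grad0",
#         "output_idx": "bias0",
#         "input_size": 4,
#         "input_buf_add": "buffer0",
#         "input_buf_add_offset": 0,
#         "QAS": np.array([0.5, 0.5, 0.25, 0.25]),
#     })
#     print(op.generate_inference_str())
#
# Assuming _getBufferstr resolves to "&buffer0[0]", this would emit C roughly
# like:
#
#     const float bias0_QAS[4] = {2.0,2.0,4.0,4.0,};
#     fptr = (float*)&buffer0[0];//gradients
#     for(int i = 0; i < 4; i++) bias0[i] -= (int32_t)(fptr[i] * blr * bias0_QAS[i]);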