mirror of
https://github.com/mit-han-lab/tinyengine.git
synced 2025-10-23 18:59:47 +08:00
support patchbased inference and se block graph optimization
This commit is contained in:
@@ -165,6 +165,19 @@ void update_SGD(float learning_rate){\n"""
|
||||
|
||||
def _genPatchInference(self):
|
||||
schedule = self.MemSche
|
||||
|
||||
# Find out the first layer for normal infernece
|
||||
first_normal_op = None
|
||||
for i, op in enumerate(schedule.layer):
|
||||
layer_info = op.get_layer_info()
|
||||
if "is_patch" not in layer_info or not layer_info["is_patch"]:
|
||||
first_normal_op = op
|
||||
break # end of patch-based
|
||||
assert first_normal_op, "Cannot find the first op for normal inference."
|
||||
first_bufferstr_for_normal_inference = first_normal_op._getBufferstr(
|
||||
first_normal_op.params["input_buf_add"], first_normal_op.params["input_buf_add_offset"]
|
||||
)
|
||||
|
||||
layer_info = schedule.layer[0].get_layer_info()
|
||||
if "is_patch" in layer_info and layer_info["is_patch"]:
|
||||
fp = self.source_handle
|
||||
@@ -217,13 +230,13 @@ void update_SGD(float learning_rate){\n"""
|
||||
+ """;
|
||||
}
|
||||
/* load partial input from the img */
|
||||
q7_t* patch_input = &buffer0[0]; // for partial input
|
||||
int start_x = MAX("""
|
||||
q7_t* patch_input = getInput(); // for partial input
|
||||
int start_x = TN_MAX("""
|
||||
+ str(first_width - self.patch_params["pad_l"])
|
||||
+ """ * j - """
|
||||
+ str(self.patch_params["pad_l"])
|
||||
+ """,0);
|
||||
int start_y = MAX("""
|
||||
int start_y = TN_MAX("""
|
||||
+ str(first_height - self.patch_params["pad_l"])
|
||||
+ """ * i - """
|
||||
+ str(self.patch_params["pad_l"])
|
||||
@@ -255,7 +268,8 @@ void update_SGD(float learning_rate){\n"""
|
||||
}
|
||||
invoke_1patch(pad_t,pad_b,pad_l,pad_r);
|
||||
/* concat the output from buffer0 (this is set manually for now) */
|
||||
q7_t* output_ptr = buffer1 + (i * """
|
||||
q7_t* output_ptr = """
|
||||
+ f"{first_bufferstr_for_normal_inference} + (i * "
|
||||
+ str(patch_out_w)
|
||||
+ """ * """
|
||||
+ str(out_w)
|
||||
@@ -288,7 +302,7 @@ void update_SGD(float learning_rate){\n"""
|
||||
}
|
||||
}
|
||||
//stage 2
|
||||
invoke();
|
||||
invoke(NULL);
|
||||
}"""
|
||||
)
|
||||
string += """
|
||||
@@ -308,12 +322,6 @@ void invoke_1patch(uint16_t pad_t, uint16_t pad_b, uint16_t pad_l ,uint16_t pad_
|
||||
layercnt += 1
|
||||
fp.write(string)
|
||||
if layer_info["op"] == "CONV_2D":
|
||||
# hardcode this memory schedule for quick implementation
|
||||
# TODO: adjust this according to model architecture and split index
|
||||
next_layer_info = schedule.layer[i + 1].get_layer_info()
|
||||
if "is_patch" not in next_layer_info or not next_layer_info["is_patch"]:
|
||||
layer_info["output_buf_add"] = "front"
|
||||
layer_info["output_buf_add_offset"] = 0
|
||||
if self.unsigned_input:
|
||||
raise Exception("unsigned input is not supported by patch-based yet")
|
||||
|
||||
|
Reference in New Issue
Block a user