support patchbased inference and se block graph optimization

This commit is contained in:
meenchen
2022-12-10 00:44:55 -05:00
parent b5d75b0e61
commit f7b7f4fd5a
9 changed files with 511 additions and 233 deletions

View File

@@ -165,6 +165,19 @@ void update_SGD(float learning_rate){\n"""
def _genPatchInference(self):
schedule = self.MemSche
# Find out the first layer for normal infernece
first_normal_op = None
for i, op in enumerate(schedule.layer):
layer_info = op.get_layer_info()
if "is_patch" not in layer_info or not layer_info["is_patch"]:
first_normal_op = op
break # end of patch-based
assert first_normal_op, "Cannot find the first op for normal inference."
first_bufferstr_for_normal_inference = first_normal_op._getBufferstr(
first_normal_op.params["input_buf_add"], first_normal_op.params["input_buf_add_offset"]
)
layer_info = schedule.layer[0].get_layer_info()
if "is_patch" in layer_info and layer_info["is_patch"]:
fp = self.source_handle
@@ -217,13 +230,13 @@ void update_SGD(float learning_rate){\n"""
+ """;
}
/* load partial input from the img */
q7_t* patch_input = &buffer0[0]; // for partial input
int start_x = MAX("""
q7_t* patch_input = getInput(); // for partial input
int start_x = TN_MAX("""
+ str(first_width - self.patch_params["pad_l"])
+ """ * j - """
+ str(self.patch_params["pad_l"])
+ """,0);
int start_y = MAX("""
int start_y = TN_MAX("""
+ str(first_height - self.patch_params["pad_l"])
+ """ * i - """
+ str(self.patch_params["pad_l"])
@@ -255,7 +268,8 @@ void update_SGD(float learning_rate){\n"""
}
invoke_1patch(pad_t,pad_b,pad_l,pad_r);
/* concat the output from buffer0 (this is set manually for now) */
q7_t* output_ptr = buffer1 + (i * """
q7_t* output_ptr = """
+ f"{first_bufferstr_for_normal_inference} + (i * "
+ str(patch_out_w)
+ """ * """
+ str(out_w)
@@ -288,7 +302,7 @@ void update_SGD(float learning_rate){\n"""
}
}
//stage 2
invoke();
invoke(NULL);
}"""
)
string += """
@@ -308,12 +322,6 @@ void invoke_1patch(uint16_t pad_t, uint16_t pad_b, uint16_t pad_l ,uint16_t pad_
layercnt += 1
fp.write(string)
if layer_info["op"] == "CONV_2D":
# hardcode this memory schedule for quick implementation
# TODO: adjust this according to model architecture and split index
next_layer_info = schedule.layer[i + 1].get_layer_info()
if "is_patch" not in next_layer_info or not next_layer_info["is_patch"]:
layer_info["output_buf_add"] = "front"
layer_info["output_buf_add_offset"] = 0
if self.unsigned_input:
raise Exception("unsigned input is not supported by patch-based yet")