一、语义分割推理测试的一般流程
前处理
(1)get image
(2)letter_box
:o_h
,o_w
,i_h
,i_w
,n_h
,n_w
(3)1/250
,CHW
,BCHW
(4)to tensor
,to cuda
前向传播:get prediction
后处理:
(1)CHW
,HWC
,softmax
(2)to numpy
,to cpu
(3)letter_box_inv
:i_h
,n_h
,o_w
,o_h
(4)armax
(5)get image
:Image.fromarray(np.uint8())
二、代码讲解
前处理
(1)get image
image = Image.open(img)
(2)letter_box
:o_h
,o_w
,i_h
,i_w
,n_h
,n_w
orininal_h = np.array(image).shape[0]
orininal_w = np.array(image).shape[1]
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))#---------------------------------------------------#
# 对输入图像进行resize
#---------------------------------------------------#
def resize_image(image, size):iw, ih = image.sizew, h = sizescale = min(w/iw, h/ih)nw = int(iw*scale)nh = int(ih*scale)image = image.resize((nw,nh), Image.BICUBIC)new_image = Image.new('RGB', size, (128,128,128))new_image.paste(image, ((w-nw)//2, (h-nh)//2))return new_image, nw, nh
(3)1/250
,CHW
,BCHW
#---------------------------------------------------------#
# 添加上batch_size维度
#---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)def preprocess_input(image):image /= 255.0return image
(4)to tensor
,to cuda
with torch.no_grad():images = torch.from_numpy(image_data)if self.cuda:images = images.cuda()
前向传播:
get prediction
#---------------------------------------------------#
# 图片传入网络进行预测
#---------------------------------------------------#
pr = self.net(images)[0]
#---------------------------------------------------#
后处理:
(1)CHW
,HWC
,softmax
(2)to numpy
,to cpu
#---------------------------------------------------#
# 取出每一个像素点的种类
#---------------------------------------------------#
pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
(3)letter_box_inv
:i_h
,n_h
,o_w
,o_h
#--------------------------------------#
# 将灰条部分截取掉
#--------------------------------------#
pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
#---------------------------------------------------#
# 进行图片的resize
#---------------------------------------------------#
pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
(4)armax
#---------------------------------------------------#
# 取出每一个像素点的种类
#---------------------------------------------------#
pr = pr.argmax(axis=-1)
(5)get image
:Image.fromarray(np.uint8())
seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
#------------------------------------------------#
# 将新图片转换成Image的形式
#------------------------------------------------#
image = Image.fromarray(np.uint8(seg_img))
#------------------------------------------------#
# 将新图与原图及进行混合
#------------------------------------------------#
image = Image.blend(old_img, image, 0.7)
三、完整代码
def detect_image(self, image, count=False, name_classes=None):#---------------------------------------------------------## 在这里将图像转换成RGB图像,防止灰度图在预测时报错。# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB#---------------------------------------------------------#image = cvtColor(image)#---------------------------------------------------## 对输入图像进行一个备份,后面用于绘图#---------------------------------------------------#old_img = copy.deepcopy(image)orininal_h = np.array(image).shape[0]orininal_w = np.array(image).shape[1]#---------------------------------------------------------## 给图像增加灰条,实现不失真的resize# 也可以直接resize进行识别#---------------------------------------------------------#image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))#---------------------------------------------------------## 添加上batch_size维度#---------------------------------------------------------#image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)with torch.no_grad():images = torch.from_numpy(image_data)if self.cuda:images = images.cuda()#---------------------------------------------------## 图片传入网络进行预测#---------------------------------------------------#pr = self.net(images)[0]#---------------------------------------------------## 取出每一个像素点的种类#---------------------------------------------------#pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()#--------------------------------------## 将灰条部分截取掉#--------------------------------------#pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]#---------------------------------------------------## 进行图片的resize#---------------------------------------------------#pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)#---------------------------------------------------## 取出每一个像素点的种类#---------------------------------------------------#pr = pr.argmax(axis=-1)seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])#------------------------------------------------## 将新图片转换成Image的形式#------------------------------------------------#image = Image.fromarray(np.uint8(seg_img))#------------------------------------------------## 将新图与原图及进行混合#------------------------------------------------#image = Image.blend(old_img, image, 0.7)