第一种做法:
import os
import sys
import random
import math
import numpy as np
import skimage.io
import matplotlib
import matplotlib.pyplot as plt
import cv2
import colorsys
# Expose only GPU index 1 to CUDA. This is set before the Mask R-CNN modules
# are imported below (presumably so the backend sees it at initialisation --
# TODO confirm against the deep-learning framework in use).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Root directory of the project (the current working directory).
ROOT_DIR = os.path.abspath("./")
print(ROOT_DIR)
# Import Mask RCNN
sys.path.append(ROOT_DIR) # To find local version of the library
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
# Import COCO config
sys.path.append(os.path.join(ROOT_DIR, "samples/coco/")) # To find local version
import coco
# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
# Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
print(COCO_MODEL_PATH)

# Folder that holds the input images for detection.
IMAGE_DIR = os.path.join(ROOT_DIR, "images")


class InferenceConfig(coco.CocoConfig):
    """COCO configuration adjusted for running inference on single images."""

    # Inference handles one image at a time, so the batch size
    # (GPU_COUNT * IMAGES_PER_GPU) is set to 1.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1


config = InferenceConfig()
config.display()


def random_colors(N, bright=True):
    """Return ``N`` RGB colour tuples in shuffled order.

    The hues are spread evenly over HSV space at full saturation, which
    keeps the colours visually distinct, and are then converted to RGB.
    Each component of a returned tuple is a float in ``[0, 1]``.

    Args:
        N: number of colours to generate.
        bright: use value 1.0 when true, 0.7 otherwise.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette
def apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into ``image`` wherever ``mask`` equals 1.

    The image is modified in place, one channel at a time, and the same
    array is returned.

    Args:
        image: array indexed as ``image[:, :, c]`` for ``c`` in 0..2.
        mask: 2-D array; pixels equal to 1 are tinted.
        color: sequence of three components scaled by 255 before blending.
        alpha: weight of the colour; ``1 - alpha`` is the weight of the image.
    """
    selected = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image
def load_model():
    """Run Mask R-CNN on a sample image and save the annotated result.

    Builds the model in inference mode, loads the MS-COCO weights, runs
    detection on an image from ``IMAGE_DIR`` and draws every detected
    instance (mask overlay, bounding box and caption) onto the image, which
    is then written to ``1.jpg`` in the working directory.

    Raises:
        FileNotFoundError: if the selected image cannot be read.
    """
    model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
    # Load weights trained on MS-COCO
    model.load_weights(COCO_MODEL_PATH, by_name=True)

    # COCO class names. The index of a class in the list is its ID; for
    # example the ID of the teddy bear class is
    # class_names.index('teddy bear').
    class_names = [
        'BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
        'bus', 'train', 'truck', 'boat', 'traffic light',
        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
        'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
        'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
        'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
        'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
        'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
        'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
        'teddy bear', 'hair drier', 'toothbrush',
    ]

    # To pick from every file in IMAGE_DIR instead, build the list with
    # next(os.walk(IMAGE_DIR))[2].
    file_names = ['3627527276_6fe8cd9bfe_z.jpg']
    image_path = os.path.join(IMAGE_DIR, random.choice(file_names))
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('cannot read image: {}'.format(image_path))

    # cv2.imread yields BGR channel order while the Mask R-CNN model works on
    # RGB images, so detection runs on a converted copy. Drawing and saving
    # below stay on the BGR original, which is what cv2.imwrite expects.
    results = model.detect([cv2.cvtColor(image, cv2.COLOR_BGR2RGB)], verbose=1)

    # Visualize results
    r = results[0]
    print(r['rois'])
    print('======================')
    print(np.array(r['masks']).shape)
    print('========================')
    print(r['class_ids'])
    print('========================')
    print(r['scores'])

    boxes = r['rois']
    masks = r['masks']
    class_ids = r['class_ids']
    scores = r['scores']
    N = boxes.shape[0]
    colors = random_colors(N)
    print('colors=', colors)

    for i in range(N):
        # Plain Python ints: some OpenCV builds reject numpy integer scalars
        # as point coordinates.
        ymin, xmin, ymax, xmax = (int(v) for v in boxes[i])
        print(boxes[i])

        # Mask: one channel of `masks` per detected instance.
        image = apply_mask(image, masks[:, :, i], colors[i])

        # Caption text. Compare against None explicitly so that a score of
        # exactly 0.0 is still printed.
        score = scores[i] if scores is not None else None
        label = class_names[class_ids[i]]
        caption = "{} {:.3f}".format(label, score) if score is not None else label
        cv2.putText(image, caption, (xmin, ymin + 8), cv2.FONT_HERSHEY_SIMPLEX,
                    0.3, color=(255, 255, 255), thickness=1)
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)

    cv2.imwrite('1.jpg', image)


if __name__ == '__main__':
    load_model()
注意 np.where 的妙用:r['masks'] 是形状为 (H, W, N) 的 ndarray(N 为检测到的实例数,本例中 N=4),每个 channel 对应一个实例的 mask;逐个取出 channel,用 np.where 只在 mask 为 1 的像素处把颜色按 alpha 混合到原图上,其余像素保持不变。
第二种做法:
#coding:utf-8
import numpy as np
import matplotlib.pyplot as plt
import cv2
import imutils
from matplotlib.patches import Polygon
import os


def get_contour(image_mask, black_mask, image, color):
    """Draw the contours of one mask channel and return the blended canvas.

    Every external contour of ``image`` that has more than 120 points is
    outlined on ``image_mask`` and filled on ``black_mask``; both canvases
    are drawn on in place, so repeated calls accumulate. The result is
    ``image_mask + 0.4 * black_mask``.

    Args:
        image_mask: canvas receiving the contour outlines.
        black_mask: canvas receiving the filled contour regions.
        image: single-channel mask; it is scaled by 255 and thresholded at
            127, so it is assumed to hold 0/1 values -- TODO confirm.
        color: colour used for the outlines.

    Returns:
        The weighted blend of the two canvases. It is always computed from
        the current canvases, including when no contour qualifies.
    """
    # Kept as module globals because the original code published them there.
    global dummy_mask, line_mask

    image_thre = cv2.threshold(image * 255, 127, 255, cv2.THRESH_BINARY)[1]
    cnts = cv2.findContours(image_thre, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # findContours returns (contours, hierarchy) in OpenCV 2.4 and 4.x but
    # (image, contours, hierarchy) in 3.x. Selecting by tuple length works for
    # all of them, whereas an "is this OpenCV 2?" test picks the hierarchy on
    # 4.x.
    contours = cnts[0] if len(cnts) == 2 else cnts[1]

    line_mask, fill_mask = image_mask, black_mask
    for contour in sorted(contours, key=cv2.contourArea, reverse=True):
        # Skip contours with 120 points or fewer.
        if contour.shape[0] <= 120:
            continue
        pig_cnt = np.squeeze(contour)
        line_mask = cv2.drawContours(image_mask, [pig_cnt], -1, color, 5, cv2.LINE_AA)
        fill_mask = cv2.drawContours(black_mask, [pig_cnt], 0, (0, 0, 240),
                                     thickness=cv2.FILLED)

    # Blend once, after all drawing. Assigning this only inside the loop left
    # the result unset (or stale from an earlier call) when no contour passed
    # the size filter.
    dummy_mask = cv2.addWeighted(line_mask, 1., fill_mask, 0.4, 0)
    return dummy_mask
def get_channel_color(channels):
    """Return ``{channel_index: (b, g, r)}`` with one random colour per channel.

    Each component is a Python ``int`` in ``0..255``. The upper bound passed
    to ``randint`` is exclusive, so 256 is used to make 255 reachable.
    """
    return {
        index: tuple(int(np.random.randint(0, 256)) for _ in range(3))
        for index in range(channels)
    }


def pig_contour():
    """Overlay the saved instance masks on their images and write the results.

    For every file in ``./image_low_clear`` that has a mask file of the same
    base name with an ``.npy`` extension in ``./show_low_clear``, each mask
    channel is drawn onto the image with ``get_contour`` and the blended
    picture is written to ``./image_mask_low_clear`` under the image's name.
    Images without a mask file, or that cannot be read, are skipped.
    """
    img_path = './image_low_clear'
    show_path = './show_low_clear'
    output_path = './image_mask_low_clear'
    os.makedirs(output_path, exist_ok=True)

    # List the mask folder once; the set gives O(1) membership tests below
    # instead of re-listing the directory for every image.
    mask_names = {name for name in os.listdir(show_path) if name.endswith('.npy')}

    # One colour per channel index, sized for the mask with the most channels.
    channel = max(
        (np.load(os.path.join(show_path, name)).shape[-1] for name in mask_names),
        default=0,
    )
    print('channel=', channel)
    color_dict = get_channel_color(channels=channel)

    for image_name in os.listdir(img_path):
        img_list_path = os.path.join(img_path, image_name)
        # Swap only the extension; replacing the substring 'jpg' would also
        # rewrite a 'jpg' that appears elsewhere in the file name.
        mask_name = os.path.splitext(image_name)[0] + '.npy'
        pig_mask_list_path = os.path.join(show_path, mask_name)
        print('img_list_path=', img_list_path)
        print('pig_mask_list_path=', pig_mask_list_path)
        if mask_name not in mask_names:
            continue

        image = cv2.imread(img_list_path)
        if image is None:
            # cv2.imread returns None for files it cannot decode.
            print('skip unreadable image:', img_list_path)
            continue

        pig_mask = np.load(pig_mask_list_path)
        image_h, image_w, channels = pig_mask.shape
        black_mask = np.zeros((image_h, image_w, 3), dtype=np.float32)
        image_mask = image.astype(np.float32)

        # Start from the plain image so that a mask with zero channels writes
        # the unmodified picture rather than a result left over from the
        # previous image.
        blended = image_mask
        for channel_index in range(channels):
            blended = get_contour(image_mask, black_mask,
                                  pig_mask[..., channel_index],
                                  color=color_dict[channel_index])
        cv2.imwrite(os.path.join(output_path, image_name), blended)


def unlock_mv(sp):
    """Split a video into JPEG frames saved as ``mv/<n>.jpg`` (n starts at 1).

    Args:
        sp: path of the video file.
    """
    os.makedirs('mv', exist_ok=True)
    cap = cv2.VideoCapture(sp)
    frame_count = 0
    try:
        while cap.isOpened():
            suc, frame = cap.read()
            if not suc:
                # The read past the last frame fails and yields no image; stop
                # before counting or writing it.
                break
            frame_count += 1
            # The former extra argument was a single-element list; imwrite
            # takes its parameters as (id, value) pairs, so it is omitted and
            # the encoder defaults are used.
            cv2.imwrite(os.path.join('mv', '%d.jpg' % frame_count), frame)
    finally:
        cap.release()
    print('unlock image: ', frame_count)


# Images -> video
def jpg_video(image_path='./image_mask_low_clear',
              output_path='predict_low_clear.avi', fps=4):
    """Assemble the numbered images of a folder into an MJPG-encoded video.

    Files are ordered by the integer before the first ``.`` in their name
    (``1.jpg``, ``2.jpg``, ...). The frame size is taken from the first image.

    Args:
        image_path: folder containing the frames.
        output_path: path of the video file to write.
        fps: frame rate of the output video.

    Raises:
        ValueError: if the folder is empty or its first image is unreadable.
    """
    names = sorted(os.listdir(image_path), key=lambda name: int(name.split('.')[0]))
    if not names:
        raise ValueError('no images found in {}'.format(image_path))
    images_list_path = [os.path.join(image_path, name) for name in names]

    first = cv2.imread(images_list_path[0])
    if first is None:
        raise ValueError('cannot read image: {}'.format(images_list_path[0]))
    h, w = first.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    # The last argument is the frame size, given as (width, height).
    videoWriter = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    try:
        for image_list_path in images_list_path:
            frame = cv2.imread(image_list_path)
            if frame is None:
                # cv2.imread returns None for files it cannot decode.
                print('skip unreadable image:', image_list_path)
                continue
            videoWriter.write(frame)
    finally:
        # Release even if a frame fails so the container is finalised.
        videoWriter.release()


# Video -> images. Original note: interval is fps plus 1; when re-assembling
# the frames into a video, fps == 6.
def video2frame(video_src_path="./data/xiaomi8_videos",
                frame_save_path="./data/xiaomi8_images/", interval=5):
    """Save every ``interval``-th frame of each video as a PNG image.

    Each file in ``video_src_path`` gets its own sub-folder under
    ``frame_save_path``, named after the video without its extension, and the
    kept frames are written there as ``0.png``, ``1.png``, ...

    Every entry of the source folder is treated as a video; there is no
    filtering by file format.

    Args:
        video_src_path: folder containing the videos.
        frame_save_path: folder under which the frame folders are created.
        interval: keep one frame out of every ``interval`` frames read.

    Raises:
        ValueError: if ``interval`` is smaller than 1.
    """
    if interval < 1:
        raise ValueError('interval must be >= 1')
    os.makedirs(frame_save_path, exist_ok=True)

    for each_video in os.listdir(video_src_path):
        print("正在读取视频:", each_video)
        # splitext copes with extensions of any length; slicing off the last
        # four characters only works for three-letter ones.
        each_video_name = os.path.splitext(each_video)[0]
        print(each_video_name)

        each_video_save_full_path = os.path.join(frame_save_path, each_video_name)
        os.makedirs(each_video_save_full_path, exist_ok=True)
        each_video_full_path = os.path.join(video_src_path, each_video)

        cap = cv2.VideoCapture(each_video_full_path)
        frame_index = 0
        frame_count = 0
        success = cap.isOpened()
        if not success:
            print("读取失败!")
        try:
            while success:
                success, frame = cap.read()
                print("---> 正在读取第%d帧:" % frame_index, success)
                # The read past the last frame fails, so `success` must be
                # checked before the frame is written.
                if success and frame_index % interval == 0:
                    cv2.imwrite(
                        os.path.join(each_video_save_full_path, "%d.png" % frame_count),
                        frame,
                    )
                    frame_count += 1
                frame_index += 1
        finally:
            cap.release()


if __name__ == '__main__':
    pig_contour()
    # jpg_video()