# 1. Convert shapefile (.shp) label data to COCO format
# -*- coding: utf-8 -*-
import os, json
import cv2
from osgeo import gdal
import numpy as np
from osgeo import ogr, gdal, osr
from shapely.geometry import box, shape
from shapely.geometry.polygon import Polygon
import collections
import datetime
import geopandas as gpd
import shutil


def read_img(filename):
    """Open a raster with GDAL and read every band into memory.

    Args:
        filename: Path of the raster file to open.

    Returns:
        A tuple ``(im_width, im_height, im_proj, im_geotrans, im_data,
        dataset)``: raster width and height in pixels, the projection
        string, the geotransform, the array returned by ``ReadAsArray`` and
        the still-open GDAL dataset handle.
    """
    dataset = gdal.Open(filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # The dataset is returned on purpose (not deleted) so callers can reuse it.
    return im_width, im_height, im_proj, im_geotrans, im_data, dataset


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a GeoTIFF file.

    Args:
        filename: Output path of the GeoTIFF.
        im_proj: Projection string stored in the output.
        im_geotrans: Geotransform stored in the output.
        im_data: Array shaped ``(bands, height, width)`` or
            ``(height, width)``.
    """
    # Substring match: any dtype whose name contains 'int8' is written as
    # Byte and any containing 'int16' as UInt16; everything else as Float32.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Always iterate over 2-D band slices so that a (1, H, W) array is
    # written the same way as an (H, W) array.
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        bands = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for band_index, band_data in enumerate(bands, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band_data)
    # Push pending writes to disk before the handle is released.
    dataset.FlushCache()


def data2YoloAndCoco(shapefile_path, tif_path):
    """Validate a shapefile/TIFF pair and create its (empty) YOLO label file.

    The label file is named after the TIFF and is created/truncated inside
    ``yolo_txt_path``; nothing is written to it here.

    NOTE(review): ``yolo_txt_path`` is not defined in the visible part of
    this file -- it must exist as a module-level name before this function
    is called. TODO confirm where it is meant to come from.

    Args:
        shapefile_path: Path of the shapefile holding the labels.
        tif_path: Path of the matching TIFF image.

    Returns:
        None. Returns early (after printing a message) when either input
        cannot be opened.
    """
    name = os.path.splitext(os.path.basename(tif_path))[0]

    shapefile_ds = ogr.Open(shapefile_path)
    if shapefile_ds is None:
        print("无法打开Shapefile文件")
        return

    tif_ds = gdal.Open(tif_path)
    if tif_ds is None:
        print("无法打开TIFF文件")
        return

    yolo_label_path = os.path.join(yolo_txt_path, name + ".txt")
    # Create/truncate the label file and close the handle right away.
    with open(yolo_label_path, 'w'):
        pass


def get_bbox_points(ring, geo_transform, x_res, y_res, width, height):
    """Return a normalized (x_center, y_center, w, h) box for a ring.

    The box is derived from the first three ring vertices only: the width
    comes from vertices 0 -> 1 and the height from vertices 1 -> 2, so the
    ring is assumed to be an axis-aligned rectangle in that vertex order.

    Args:
        ring: Object exposing ``GetPoint(i)`` returning map coordinates.
        geo_transform: Geotransform; elements 0 and 3 are used as origin.
        x_res: Pixel size along x.
        y_res: Pixel size along y.
        width: Image width in pixels, used for normalization.
        height: Image height in pixels, used for normalization.

    Returns:
        ``(x_normalized, y_normalized, width_normalized, height_normalized)``
        as absolute values.
    """
    origin_x, origin_y = geo_transform[0], geo_transform[3]

    def to_pixel(point):
        # Map coordinates -> integer pixel coordinates (truncated).
        return (int((point[0] - origin_x) / x_res),
                int((point[1] - origin_y) / y_res))

    x1, _ = to_pixel(ring.GetPoint(0))
    x2, y2 = to_pixel(ring.GetPoint(1))
    _, y3 = to_pixel(ring.GetPoint(2))

    w = x2 - x1
    h = y3 - y2
    x_center = x1 + w / 2.0
    y_center = y2 + h / 2.0

    x_normalized = abs(x_center / width)
    y_normalized = abs(y_center / height)
    width_normalized = abs(w / width)
    height_normalized = abs(h / height)
    return x_normalized, y_normalized, width_normalized, height_normalized


def get_boundary_points(geom, geo_transform, x_res, y_res):
    """Convert every vertex of a geometry to integer pixel coordinates.

    Args:
        geom: Object exposing ``GetPointCount()``, ``GetX(i)``, ``GetY(i)``.
        geo_transform: Geotransform; elements 0 and 3 are used as origin.
        x_res: Pixel size along x.
        y_res: Pixel size along y.

    Returns:
        ``(x_pixels, y_pixels, pixels)``: the x values, the y values and the
        ``[x, y]`` pairs, all in vertex order. All three are empty when the
        geometry has no points.
    """
    origin_x, origin_y = geo_transform[0], geo_transform[3]
    x_pixels = []
    y_pixels = []
    pixels = []
    for j in range(geom.GetPointCount()):
        pixel_x = int((geom.GetX(j) - origin_x) / x_res)
        pixel_y = int((geom.GetY(j) - origin_y) / y_res)
        x_pixels.append(pixel_x)
        y_pixels.append(pixel_y)
        pixels.append([pixel_x, pixel_y])
    return x_pixels, y_pixels, pixels


def getsegmenation(x_pixels, y_pixels):
    """Build the bounding box of a set of pixels and its outline polygon.

    Args:
        x_pixels: Non-empty sequence of x pixel coordinates.
        y_pixels: Non-empty sequence of y pixel coordinates.

    Returns:
        ``(box_info, bounding_box_area, segmentation)`` where ``box_info`` is
        ``[minx, miny, box_w, box_h]``, ``bounding_box_area`` is
        ``box_w * box_h`` and ``segmentation`` is a one-element list holding
        the flat ``[x1, y1, x2, y2, ...]`` outline of the box.

    Raises:
        ValueError: If either sequence is empty (raised by ``min``/``max``).
    """
    minx, maxx = min(x_pixels), max(x_pixels)
    miny, maxy = min(y_pixels), max(y_pixels)
    box_w = maxx - minx
    box_h = maxy - miny
    bounding_box_area = box_w * box_h
    box_info = [minx, miny, box_w, box_h]
    # Walk the corners around the perimeter; listing them row by row would
    # describe a self-intersecting "bowtie" rather than a rectangle.
    outline = [minx, miny, maxx, miny, maxx, maxy, minx, maxy]
    return box_info, bounding_box_area, [outline]


def _new_coco_dataset(now):
    """Return an empty COCO dictionary stamped with the time ``now``."""
    return dict(
        info=dict(
            description=None,
            url=None,
            version=None,
            year=now.year,
            contributor=None,
            date_created=now.strftime('%Y-%m-%d %H:%M:%S.%f'),
        ),
        licenses=[dict(url=None, id=0, name=None)],
        images=[],       # license, url, file_name, height, width, date_captured, id
        type='instances',
        annotations=[],  # segmentation, area, iscrowd, image_id, bbox, category_id, id
        categories=[],   # supercategory, id, name
    )


def _build_categories(cls_dict):
    """Return COCO category entries for a ``{class id string: name}`` map."""
    return [
        {"id": int(class_id), "name": class_name, "supercategory": ""}
        for class_id, class_name in sorted(cls_dict.items(),
                                           key=lambda item: int(item[0]))
    ]


def _append_annotations(shapefile_layer, geo_transform, cls_dict, image_id,
                        annotations):
    """Append one bounding-box annotation per usable feature of a layer."""
    x_res = geo_transform[1]
    y_res = geo_transform[5]
    for feature in shapefile_layer:
        geometry = feature.GetGeometryRef()
        if geometry is None:
            continue
        ring = geometry.GetGeometryRef(0)
        if ring is None:
            continue

        # The "Class" field stores the class id.
        class_id = str(feature.GetField("Class"))
        if class_id not in cls_dict:
            print('class_id is empty!')
            continue

        x_pixels, y_pixels, _ = get_boundary_points(
            ring, geo_transform, x_res, y_res)
        if not x_pixels or not y_pixels:
            continue

        bbox, area, bbox_points = getsegmenation(x_pixels, y_pixels)
        annotations.append({
            "id": len(annotations),
            "image_id": image_id,
            "category_id": int(class_id),
            "segmentation": bbox_points,
            "area": area,
            "bbox": bbox,
            "iscrowd": 0,
        })


def main():
    """Convert every shapefile/TIFF pair under the train folders to COCO.

    Walks ``<label root>/<site>/<region>/*.shp``, pairs each shapefile with
    the same-named ``.tif`` under the image root, copies the TIFF to a flat
    folder, records one image entry plus one annotation per labelled feature
    and finally dumps everything to a single JSON file.
    """
    data = _new_coco_dataset(datetime.datetime.now())

    # Class id (as stored in the shapefile "Class" field) -> class name.
    cls_dict = {'1': 'pine', '2': 'spruce', '3': 'birch', '4': 'populus'}
    data["categories"].extend(_build_categories(cls_dict))

    root_tiff_folder = './data4train/train_image_128/'
    root_shpf_folder = './data4train/train_label_128/'
    tiff_copy_folder = './data4train/TIFFImage-train-128/'
    out_json_file = './STDtrain128.json'

    # shutil.copy does not create the destination folder.
    os.makedirs(tiff_copy_folder, exist_ok=True)

    image_id = 0
    for sitname in os.listdir(root_shpf_folder):
        site_folder = os.path.join(root_shpf_folder, sitname)
        if not os.path.isdir(site_folder):
            continue
        for regionn in os.listdir(site_folder):
            shpf_folder = os.path.join(root_shpf_folder, sitname, regionn)
            tiff_folder = os.path.join(root_tiff_folder, sitname, regionn)
            if not os.path.isdir(shpf_folder):
                continue

            for shpfile in os.listdir(shpf_folder):
                if shpfile[-4:] != ".shp":
                    continue
                print('Processing shpfile:', shpfile)
                shpfile_path = os.path.join(shpf_folder, shpfile)
                shpfile_name = os.path.splitext(shpfile)[0]
                tiffile_name = shpfile_name + '.tif'
                tiffile_path = os.path.join(tiff_folder, tiffile_name)

                # Open both inputs first so a broken pair is skipped before
                # anything is copied or recorded.
                dataset = gdal.Open(tiffile_path)
                if dataset is None:
                    print("无法打开TIFF文件")
                    continue
                shapefile_ds = ogr.Open(shpfile_path)
                if shapefile_ds is None:
                    print("无法打开Shapefile文件")
                    continue

                # Copy the TIFF image into the flat output folder.
                shutil.copy(tiffile_path,
                            os.path.join(tiff_copy_folder, tiffile_name))

                data['images'].append(dict(
                    license=0,
                    url=None,
                    file_name=tiffile_name,
                    height=dataset.RasterYSize,
                    width=dataset.RasterXSize,
                    date_captured=None,
                    id=image_id,
                ))

                _append_annotations(
                    shapefile_ds.GetLayer(),
                    dataset.GetGeoTransform(),
                    cls_dict,
                    image_id,
                    data["annotations"],
                )

                image_id += 1
                # Drop the GDAL/OGR handles before moving to the next file.
                shapefile_ds = None
                dataset = None

    with open(out_json_file, 'w') as f:
        json.dump(data, f)


if __name__ == "__main__":
    main()