数据集简介
本数据集包含:
训练集:43685张;
验证集:5363张;
测试集:5363张;
总类别数:158类。
部分代码:
定义数据集
class MyDataset(Dataset):
    """Image-classification dataset driven by a ``{mode}.txt`` index file.

    Every usable line of the index file holds an image path and a label
    separated by a single space. Samples are stored in ``self.data`` as
    ``[image_path, label_string]`` pairs.
    """

    def __init__(self, mode='train', transform=None):
        """Read the index file for ``mode`` and collect the samples.

        Args:
            mode: Split name; the index file is ``f'{data_path}{mode}.txt'``.
            transform: Optional callable applied to the image array in
                ``__getitem__``.
        """
        super(MyDataset, self).__init__()
        self.data = []
        self.transform = transform
        # NOTE(review): the index path concatenates data_path and mode with no
        # separator, while image paths below insert '/'. data_path is defined
        # outside this block -- confirm it ends with a path separator.
        with open(f'{data_path}{mode}.txt') as f:
            for line in f:
                info = line.strip().split(' ')
                # str.split never returns an empty list, so a blank or
                # one-field line must be rejected by requiring two fields.
                if len(info) < 2:
                    continue
                self.data.append(
                    [data_path + '/' + info[0].strip(), info[1].strip()]
                )

    def __getitem__(self, idx):
        """Return ``(img, label)`` for the sample at ``idx``.

        The image is loaded as RGB and converted to a NumPy array before the
        optional transform is applied; the label is returned as a one-element
        int64 array.
        """
        image_file, label = self.data[idx]
        # The context manager closes the underlying file handle once the
        # pixel data has been copied into the array.
        with Image.open(image_file) as src:
            img = np.array(src.convert('RGB'))
        # NOTE(review): a leftover note in the original suggests the transform
        # yields a float32 tensor of shape [3, 227, 227] -- confirm.
        if self.transform is not None:
            img = self.transform(img)
        label = np.array([label], dtype="int64")
        return img, label

    def __len__(self):
        """Return the number of samples read from the index file."""
        # The body was missing in the original excerpt; the sample count is
        # the length of the list populated in __init__.
        return len(self.data)
定义ResNet网络
# Build a ResNet-50 whose classifier has 158 outputs, matching the
# dataset's 158 classes stated in the introduction.
resnet50 = paddle.vision.models.resnet50(num_classes=158)
取单张测试图片进行可视化展示
import pylab as pl
import matplotlib.font_manager as fm

# Show one test image with the predicted class name as its title.
test_path = '/home/aistudio/Mydata/test1.txt'
# Font file used to render the Chinese characters in the plot title.
myfont = fm.FontProperties(fname=r'/home/aistudio/simkai.ttf')
jetson_path = '/home/aistudio/Mydata/garbage_classification.json'

# Mapping from label (as a string key) to class name.
with open(jetson_path, 'r') as f1:
    load_dict = json.load(f1)

# Only the first line of the test index is used; its first field is the
# image path relative to the dataset directory.
with open(test_path, 'r') as f2:
    img_path = f2.readline().strip().split(' ')

test_img_path = f'/home/aistudio/Mydata/{img_path[0]}'
print('输入测试图片路径为:')
print(test_img_path)

# NOTE(review): lab1 is not defined in this excerpt; presumably it is the
# predicted label produced by the inference step -- confirm.
clas = load_dict[f'{lab1}']  # look up the class name for the label

img = cv2.imread(test_img_path)
if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError(f'cannot read image: {test_img_path}')
# OpenCV loads images in BGR order, while matplotlib expects RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.title(f'预测:{clas}', fontproperties = myfont, fontsize=20)