用cityscapes fine tune yolov8-seg

cityscapes数据集预处理

import os
import random
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transformsdef get_subfolders_with_path(folder_path):subfolders = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, f))]subfolders.sort()return subfoldersdef get_files_in_folder(folder_path):files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]files.sort()return filesclass CustomDataset(Dataset):def __init__(self):# 指定文件夹路径folder_path = '/home/Downloads/cityspaces/leftImg8bit/train/'# 获取带路径的文件夹列表folder_list_with_path = get_subfolders_with_path(folder_path)self.all_images_path = []# 获取文件列表for folder_i in folder_list_with_path:file_list = get_files_in_folder(folder_i)self.all_images_path.extend(file_list)def __len__(self):return len(self.all_images_path)def __getitem__(self, item):name_i = self.all_images_path[item]my_string = name_i[:-4]# 找到第一个 "leftImg8bit" 的索引位置first_index = my_string.find("leftImg8bit")# 找到第二个 "leftImg8bit" 的索引位置,从第一个之后开始搜索second_index = my_string.find("leftImg8bit", first_index + 1)# 使用切片和 replace() 方法替换第二个 "leftImg8bit" 为 "gt_Fine_labelids.png"new_string = my_string[:second_index] + "gtFine_labelIds.png" + my_string[second_index + len("leftImg8bit"):]# 替换第一个leftImg8bitlabel_path = new_string.replace("leftImg8bit", "gtFine", 1)image = cv2.imread(name_i)image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# plt.imshow(image)# plt.show()label = cv2.imread(label_path, 0)input_width = 512input_height = 512image = cv2.resize(image, (input_width, input_height), interpolation=cv2.INTER_NEAREST)mean = np.array([0.485, 0.456, 0.406]) * 255std = np.array([0.229, 0.224, 0.225]) * 255# 对图像进行归一化image = (image - mean) / std# 0-33 34个标签  yoloV8-seg输出0-31 32个标签label[label > 31] = 31label[label == 1] = 7label = cv2.resize(label, (512, 512), interpolation=cv2.INTER_NEAREST)# plt.subplot(121)# plt.imshow(image)# plt.subplot(122)# plt.imshow(label)# plt.show()label = torch.from_numpy(label)return image, labeldef get_dataloader():batch_size = 8train_set = CustomDataset()train_loader_ = DataLoader(train_set, batch_size=batch_size, shuffle=False, drop_last=False)return train_loader_if __name__ == "__main__":train_loader = get_dataloader()for batch_idx, (data, target) in enumerate(train_loader):print(batch_idx)

加载yolo8-seg模型

# Load YOLOv8n-seg, train it on COCO128-seg for 3 epochs and predict an image with it
from ultralytics import YOLO
import matplotlib.pyplot as pltmodel = YOLO('yolov8n-seg.pt')  # load a pretrained YOLOv8n segmentation model
# Train the model
# results = model.train(data='coco128-seg.yaml', epochs=100, imgsz=640)
output = model("/home/robotics/dino/img/IMAGE0000016.jpg")  # predict on an image

这样输出的output是没有梯度的,不能训练。想要训练就要调用注释掉的train方法,需要提前按照coco格式准备好数据集,如果不想制作coco数据集的格式,通过这种方法拿出模型的带梯度的输出

from ultralytics import YOLO
import torch
import numpy as np
import cv2
import torch.nn.functional as F
import matplotlib.pyplot as plt# Load a model
yoloSeg = YOLO('yolov8x-seg.yaml').load('yolov8x-seg.pt')  # build from YAML and transfer weightsname = "/home/dino/img/student_building/b1.jpg"
img0 = cv2.imread(name)
img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
img0 = cv2.resize(img0, (512, 512))mean = np.array([0.485, 0.456, 0.406]) * 255
std = np.array([0.229, 0.224, 0.225]) * 255
# 对图像进行归一化
img = (img0 - mean) / stdimg = torch.from_numpy(img).unsqueeze(0).to(torch.float32).to("cuda")
img = img.permute(0, 3, 1, 2)yoloSeg.model = yoloSeg.model.to("cuda")output = yoloSeg.model(img)
result = output[2]result = F.interpolate(result, size=(512, 512), mode='nearest')max_indices = torch.argmax(result, dim=1)
result = torch.squeeze(max_indices, dim=1)
result = result.cpu().numpy()
result = result.astype(np.uint8)
result = result.squeeze()plt.subplot(121)
plt.imshow(img0)
plt.subplot(122)
plt.imshow(result, cmap='viridis')
plt.title('output'), plt.xticks([]), plt.yticks([])
plt.show()print("done")

调用的函数是ultralytics/nn/tasks.py中 BaseModel的forward方法

对模型进行训练

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from prepare_data import get_dataloader
from ultralytics import YOLOif __name__ == "__main__":device = torch.device("cuda")writer = SummaryWriter('./log')  # 指定日志保存的目录train_loader = get_dataloader()criterion = nn.CrossEntropyLoss()yoloSeg = YOLO('yolov8x-seg.yaml').load('yolov8x-seg.pt')  # build from YAML and transfer weightsmodel = yoloSeg.model.to("cuda")optimizer = optim.AdamW(model.parameters(), lr=1e-5)num_epochs = 300for epoch in range(1, num_epochs+1):epoch_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):data = data.to(device).float()data = data.permute(0, 3, 1, 2)target = target.to(device).to(torch.int64)output = model(data)[2]output = F.interpolate(output, size=(512, 512), mode='nearest')loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()epoch_loss += loss.item()if batch_idx % 10 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, batch_idx+1, len(train_loader), loss.item()))avg_epoch_loss = epoch_loss / len(train_loader)writer.add_scalar('Training Loss', avg_epoch_loss, epoch)if epoch % 10 == 0:torch.save(model.state_dict(), "./my_checkpoints/my_train_temp.pth")print(f"Model weights saved.")writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/648333.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode410.Split Array Largest Sum——二分答案

文章目录 一、题目二、题解 一、题目 Given an integer array nums and an integer k, split nums into k non-empty subarrays such that the largest sum of any subarray is minimized. Return the minimized largest sum of the split. A subarray is a contiguous part…

DDColor:AI图像着色工具,优秀的黑白图像上色模型,支持双解码器!

前言 在数字图像处理领域,图像上色 一直是一个重要的课题。传统的图像上色方法通常需要人工干预,耗时且效果有限。 然而,随着深度学习技术的发展,自动图像上色模型逐渐成为了研究热点。 其中,DDColor 图像上色模型以…

vue3+elementPlus pc和小程序ai聊天文生图

websocket封装可以看上一篇文章 //pc端 <template><div class"common-layout theme-white"><el-container><el-aside><div class"title-box"><span>AI Chat</span></div><div class"chat-list&…

openssl3.2/test/certs - 060 - any.bad.com is excluded by CA2.

文章目录 openssl3.2/test/certs - 060 - any.bad.com is excluded by CA2.概述笔记END openssl3.2/test/certs - 060 - any.bad.com is excluded by CA2. 概述 openssl3.2 - 官方demo学习 - test - certs 笔记 /*! * \file D:\my_dev\my_local_git_prj\study\openSSL\test…

iOS推送通知

文章目录 一、推送通知的介绍1. 简介2. 通知的分类 二、本地通知1. 本地通知的介绍2. 实现本地通知3. 监听本地通知的点击 三、远程通知1. 什么是远程通知2. 为什么需要远程通知3. 远程通知的原理4. 如何做远程通知5. 远程通知证书配置6. 获取远程推送要用的 DeviceToken7. 测试…

java 利用gdal,转换shpfile的坐标系

文章目录 简要说明maven依赖样例代码 简要说明 在java开发中&#xff0c;利用gdal&#xff0c;将原本shpfile的坐标系转为想要的坐标系&#xff0c;并输出新的shpfile maven依赖 <!--需要安装完gdal后&#xff0c;本地install gdal包才能使用 --><!--gdal安装可参考…

代码随想录算法训练营第30天 | 回溯总结 + 3道Hard题目

今日任务 332.重新安排行程 51. N皇后 37. 解数独 总结 总结 回溯总结&#xff1a;代码随想录 回溯是递归的副产品&#xff0c;只要有递归就会有回溯&#xff0c;所以回溯法也经常和二叉树遍历&#xff0c;深度优先搜索混在一起&#xff0c;因为这两种方式都是用了递归。 …

《WebKit 技术内幕》学习之十三(1):移动WebKit

1 触控和手势事件 1.1 HTML5规范 随着电容屏幕的流行&#xff0c;触控操作变得前所未有的流行起来。时至今日&#xff0c;带有多点触控功能已经成为了移动设备的标准配置&#xff0c;基于触控的手势识别技术也获得巨大的发展&#xff0c;如使用两个手指来缩放应用的大小等。…

DAY11_(简易版)VUEElement综合案例

目录 1 VUE1.1 概述1.1.1 Vue js文件下载 1.2 快速入门1.3 Vue 指令1.3.1 v-bind & v-model 指令1.3.2 v-on 指令1.3.3 条件判断指令1.3.4 v-for 指令 1.4 生命周期1.5 案例1.5.1 需求1.5.2 查询所有功能1.5.3 添加功能 2 Element2.0 element-ui js和css和字体图标下载2.1 …

Vulnhub靶场DC-6

攻击机192.168.223.128 靶机192.168.223.134 主机发现:nmap -sP 192.168.223.0/24 端口扫描 nmap -sV -p- -A 192.168.223.134 开启了22 80端口&#xff0c;80是apache 2.4.25 先进入web界面看一下 用ip进不去&#xff0c;应该被重定向到wordy.com vim /etc/hosts 加上 19…

[SAP ABAP] ABAP编程中SY-SUBRC值的含义

在ABAP编程中&#xff0c;SY-SUBRC是一个系统变量&#xff0c;用于表示最近一次执行的系统命令(例如数据库操作、函数模块调用等)的结果状态码 SY-SUBRC的值用于检查命令是否执行成功&#xff0c;通常用于控制程序的流程 查询数据 使用SELECT语句选择查询 SY-SUBRC 0 &qu…

亚信安慧AntDB构建未来数据库典范

亚信安慧AntDB是一款数据库管理系统&#xff0c;它采用全球影响力大、社区繁荣、开放度高、生态增长迅速的PG内核。这款系统具有卓越的性能和稳定性&#xff0c;在全球范围内备受用户青睐。 与此同时&#xff0c;AntDB的社区也是充满活力的&#xff0c;用户可以在社区中交流经…

Vue中使用TypeScript:全面指南和最佳实践

🚀 欢迎来到我的专栏!专注于Vue3的实战总结和开发实践分享,让你轻松驾驭Vue3的奇妙世界! 🌈✨在这里,我将为你呈现最新的Vue3技术趋势,分享独家实用教程,并为你解析开发中的难题。让我们一起深入Vue3的魅力,助力你成为Vue大师! 👨‍💻💡不再徘徊,快来关注…

WebSocket服务端数据推送及心跳机制(Spring Boot + VUE):

文章目录 一、WebSocket简介&#xff1a;二、WebSocket通信原理及机制&#xff1a;三、WebSocket特点和优点&#xff1a;四、WebSocket心跳机制&#xff1a;五、在后端Spring Boot 和前端VUE中如何建立通信&#xff1a;【1】在Spring Boot 中 pom.xml中添加 websocket依赖【2】…

都学Python了,C++难道真的用不着了吗?

都学Python了&#xff0c;C难道真的用不着了吗&#xff1f; 在开始前我分享下我的经历&#xff0c;刚入行时遇到一个好公司和师父&#xff0c;给了我机会&#xff0c;两年时间从3k薪资涨到18k的&#xff0c; 我师父给了一些【C 】学习方法和资料&#xff0c;让我不断提升自己…

python笔记6

目录 字符串的编码和解码 编码&#xff08;Encode&#xff09;&#xff1a; 解码&#xff08;Decode&#xff09;&#xff1a; replace,ignore,strict 字符串的编码和解码 字符串的编码和解码是涉及将文本数据转换为字节序列或将字节序列转换为文本数据的过程。在Python中…

单片机介绍

本文为博主 日月同辉&#xff0c;与我共生&#xff0c;csdn原创首发。希望看完后能对你有所帮助&#xff0c;不足之处请指正&#xff01;一起交流学习&#xff0c;共同进步&#xff01; > 发布人&#xff1a;日月同辉,与我共生_单片机-CSDN博客 > 欢迎你为独创博主日月同…

关于axios给后端发送数据的问题

这里需要用的插件&#xff1a;qs.js&#xff0c;是前端给后端发送的数组&#xff0c;需要序列化所以要用到这个插件&#xff0c;这里就提取连接在这里&#xff0c;需要的自提&#xff0c;需要导如进来&#xff0c;别忘记了 链接&#xff1a;https://pan.baidu.com/s/1qyD8v9wfd…

拓展全球市场:静态代理IP成为跨境电商战略的关键工具

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

EIGRP实验

实验大纲 一、基本配置 1.构建网络拓扑结构图 2.路由器基本配置 3.配置PC 4.测试连通性 5.保存配置文件 二、配置EIGRP 1.查看路由表 2.配置EIGRP动态路由 3.查看路由器路由表 4.测试网络连通性 5.查看所有路由器的路由协议 6.保存配置文件 三、配置OSPF 1.配置…