Environment setup
- Install Anaconda, create an environment with conda create -n pytorch python=3.12, then switch to it with conda activate pytorch.
- Install pytorch with conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia, then run import torch followed by torch.cuda.is_available() to check whether cuda is available.
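The check in the last step can be wrapped in a small script. This is only a sketch: the function name cuda_status is made up for illustration, and the lazy import is an assumption so that the script still runs in an environment where torch is not installed yet.

```python
import importlib.util


def cuda_status():
    """Describe whether PyTorch is installed and whether it can see a CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed in this environment"
    import torch  # imported lazily so the script also runs without PyTorch

    if torch.cuda.is_available():
        return f"CUDA is available: {torch.cuda.get_device_name(0)}"
    return "torch is installed, but CUDA is not available"


print(cuda_status())
```

If the last message appears on a machine that does have an NVIDIA GPU, the installed driver and the pytorch-cuda version are the first things worth comparing.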
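Choosing an editor
- I had always written python in vscode, but after watching the videos I found that pycharm is really nice to use: the python console lets you inspect variables and their attributes directly, structure lists the methods, and terminal is the terminal.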
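Shortcuts
- When an error is flagged, move the cursor onto it and press alt+enter for a quick fix; this is handy when a package has not been imported.
- To look at an attribute or a method, move onto it, hold ctrl, and click to jump to its definition.
- To see which parameters a method expects, press ctrl+P.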
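Two handy functions: dir and help
- dir: lists all attributes and methods of an object
- help: shows how a specific item is used (in IPython or Jupyter, ?? also works)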
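A quick way to try both functions, using the standard-library module os.path as a stand-in so that the snippet runs without torch (the choice of module is only for illustration; the same calls work on torch or any other package):

```python
import os

# dir() returns the names of all attributes and methods as a list of strings.
# Names starting with an underscore are filtered out to keep the output short.
public_names = [name for name in dir(os.path) if not name.startswith("_")]
print(public_names)

# help() prints the signature and docstring of one specific function.
help(os.path.join)
```

A typical workflow is to use dir to find out what a package contains, then help on the one name that looks relevant.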
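jupyter notebook
- In the pytorch environment, pip install jupyter notebook installs jupyter notebook, but if its location has not been added to the PATH environment variable the jupyter command may not be found. Running python -m jupyter notebook instead starts the notebook that belongs to the current environment's interpreter, which gives finer-grained control over which installation is used.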
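To confirm which environment a notebook or script is actually running in, the interpreter location can be printed from inside it. This is a small sketch; the variable name environment_root is chosen only for this example.

```python
import sys

# Root folder of the Python installation that is running this code.
# Inside the pytorch conda environment it should point at that environment's folder.
environment_root = sys.prefix
print(environment_root)

# Full path of the interpreter itself, the same one that python -m jupyter notebook uses.
print(sys.executable)
```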
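Dataset
Images are read with PIL to demonstrate how a Dataset is written: subclass Dataset, then implement the __getitem__ and __len__ methods. The code expects the image folders at dataset/train/ants and dataset/train/bees: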
from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)


root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset
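The last line works because torch's Dataset defines __add__, which returns a ConcatDataset of the two operands. The following dependency-free sketch imitates that behavior to show what indexing into the combined dataset does. ListDataset and Concat are hypothetical stand-ins written for this example, not torch classes, and the file names are made up.

```python
class Concat:
    """Chains several datasets so they can be indexed as one."""

    def __init__(self, datasets):
        self.datasets = list(datasets)

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, idx):
        # Walk through the datasets, shifting the index past each one
        # until it falls inside the current dataset.
        for dataset in self.datasets:
            if idx < len(dataset):
                return dataset[idx]
            idx -= len(dataset)
        raise IndexError("index out of range")


class ListDataset:
    """Holds sample names that all share a single label."""

    def __init__(self, samples, label):
        self.samples = samples
        self.label = label

    def __getitem__(self, idx):
        return self.samples[idx], self.label

    def __len__(self):
        return len(self.samples)

    def __add__(self, other):
        return Concat([self, other])


ants = ListDataset(["ant_0.jpg", "ant_1.jpg", "ant_2.jpg"], "ants")
bees = ListDataset(["bee_0.jpg", "bee_1.jpg"], "bees")
train = ants + bees

print(len(train))
print(train[3])
```

Index 3 is past the three ant samples, so it resolves to the first bee sample.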
Using Tensorboard
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

writer = SummaryWriter("logs")
image_path = "data/train/bees_image/85112639_6e860b0469.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
print(type(img_array))
print(img_array.shape)

# The numpy array is laid out as (height, width, channels), hence dataformats='HWC'
writer.add_image("test", img_array, 2, dataformats='HWC')

# y = 2x
for i in range(100):
    writer.add_scalar("y=2x", 2 * i, i)

writer.close()
Run the following command in the terminal to view tensorboard; the log folder name and the port number can be set to whatever you like:
tensorboard --logdir=logs --port=6007
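On the dataformats argument used with add_image: add_image assumes a (channels, height, width) layout by default, while an array built from a PIL image is (height, width, channels). The pure-Python sketch below shows what moving between the two layouts means on a tiny 2x2 image. The helper hwc_to_chw is hypothetical and uses nested lists only so that it runs without numpy.

```python
def hwc_to_chw(image):
    """Reorder a nested list from [height][width][channel] to [channel][height][width]."""
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [
        [[image[y][x][c] for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 2x2 image with 3 channels; every pixel is a [R, G, B] triple.
image = [
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
]
chw = hwc_to_chw(image)
print(chw[0])  # the red channel as a 2x2 grid
```

Passing dataformats='HWC' tells add_image to do this reordering itself, so the array does not have to be transposed by hand.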
Some uses of Transforms
P9.transforms.py
from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# Python usage -> tensor data type
# Use transforms.ToTensor to look at two questions
# 2. Why do we need the Tensor data type

img_path = "data/train/ants_image/5650366_e22b7e1065.jpg"
img = Image.open(img_path)

writer = SummaryWriter("logs")

# 1. How transforms are used (in Python)
# Instantiate ToTensor
tensor_trans = transforms.ToTensor()
# Call the instance, which invokes the __call__ method of transforms.ToTensor
tensor_img = tensor_trans(img)

print(tensor_img)

writer.add_image("tensor_img", tensor_img)

writer.close()

# Reading the image with opencv
# import cv2
# cv_img = cv2.imread(img_path)
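The pattern "instantiate first, then call the instance" relies on Python's __call__ method. A minimal sketch of the idea, with a simplified chaining class in the spirit of transforms.Compose; Scale, Shift, and Pipeline are hypothetical names for this example and are not part of torchvision:

```python
class Scale:
    """Multiplies every value by a fixed factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, values):
        return [value * self.factor for value in values]


class Shift:
    """Adds a fixed offset to every value."""

    def __init__(self, offset):
        self.offset = offset

    def __call__(self, values):
        return [value + self.offset for value in values]


class Pipeline:
    """Applies a list of callables in order, feeding each output into the next step."""

    def __init__(self, steps):
        self.steps = steps

    def __call__(self, values):
        for step in self.steps:
            values = step(values)
        return values


double = Scale(2)          # configure the transform once
print(double([1, 2, 3]))   # then call the instance like a function

pipeline = Pipeline([Scale(2), Shift(1)])
result = pipeline([1, 2, 3])
print(result)
```

Separating configuration (the constructor) from application (the call) is what lets one configured transform be reused on many images.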
P10.UsefulTransforms.py
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

writer = SummaryWriter("logs")
img = Image.open("data/train/bees_image/16838648_415acd9e3f.jpg")
print(img)

# Using ToTensor
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("Totensor", img_tensor)

# Using Normalize
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize", img_norm)

# Using Resize
print(img.size)
trans_resize = transforms.Resize((512, 512))
img_resize = trans_resize(img)
# img_resize PIL -> totensor -> img_resize tensor
img_resize = trans_totensor(img_resize)
writer.add_image("Resize", img_resize, 0)
print(img_resize)

# Compose - resize - 2
trans_resize_2 = transforms.Resize(512)
# PIL -> PIL -> tensor
trans_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize_2 = trans_compose(img)
writer.add_image("Resize", img_resize_2, 1)

# RandomCrop
trans_random = transforms.RandomCrop((200, 300))
trans_compose_2 = transforms.Compose([trans_random, trans_totensor])
for i in range(10):
    img_crop = trans_compose_2(img)
    writer.add_image("RandomCrop", img_crop, i)

writer.close()
All of these use transforms to process an image and then view the intermediate results in TensorBoard.
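Two of the transforms above boil down to simple arithmetic, which can be checked by hand. Normalize computes (value - mean) / std per channel, so a mean and std of 0.5 map the [0, 1] range produced by ToTensor onto [-1, 1]. Resize with a single integer scales the shorter edge to that size and keeps the aspect ratio. The helpers below are hypothetical re-implementations for illustration only; in particular, the truncating rounding in resize_shorter_edge is an assumption and may differ slightly from what torchvision does internally.

```python
def normalize(value, mean, std):
    """Per-channel formula applied by Normalize: (value - mean) / std."""
    return (value - mean) / std


def resize_shorter_edge(width, height, size):
    """Output (width, height) when the shorter edge is scaled to `size`."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


# With mean=0.5 and std=0.5 the [0, 1] range maps onto [-1, 1].
print(normalize(0.0, 0.5, 0.5), normalize(1.0, 0.5, 0.5))

# A 500x375 image (width x height, the order PIL reports in img.size).
print(resize_shorter_edge(500, 375, 512))
```

This is why Resize((512, 512)) can distort a non-square image while Resize(512) does not.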
Code for using some datasets
import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

# print(test_set[0])
# print(test_set.classes)
#
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()

# print(test_set[0])

writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()
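The commented-out line test_set.classes[target] turns the integer label into a class name. A sketch of that lookup without downloading the dataset; the list below is the standard CIFAR10 class order, while class_name is a hypothetical helper:

```python
# CIFAR10 class names in label order, matching what test_set.classes holds.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def class_name(target):
    """Map an integer label (0-9) to its human-readable class name."""
    return CIFAR10_CLASSES[target]


print(class_name(3))
```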
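Pycharm breakpoints not working
At first this problem was unsolved: for some reason the breakpoints had no effect. The error message was as follows:
run worked fine, but for some reason debug did not.
The problem has since been solved. I tried all sorts of fixes found online and none of them worked; in the end, updating Pycharm fixed it automatically.
This post is based on the 小土堆 tutorial videos.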