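I. Two handy helper functions
1. dir()
Opens a package or object and shows what it contains.
2. help()
Shows the documentation (the "manual") for an object.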
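As a quick illustration, both helpers can be tried on any module. The sketch below uses the standard-library `math` module purely as a stand-in (an assumption for illustration); the same calls work on `torch` or any other package.

```python
import math

# dir() lists the names that a module (or any object) exposes.
names = dir(math)
print("sqrt" in names)  # check whether a given function lives in the module

# help() prints the documentation of one specific object.
# Pass the function itself, without parentheses.
help(math.sqrt)
```

A typical workflow is to call `dir()` repeatedly to drill down into a package, then call `help()` once a function of interest is found.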
II. Differences between the three ways of running code

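III. How to load data
PyTorch offers two data abstractions: Dataset and DataLoader.
Dataset
A Dataset has to answer two questions:
1. How do we get each sample and its label?
2. How many samples are there in total?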
```python
from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # Folder holding the images for this label, e.g. dataset/train/ants
        self.path = os.path.join(self.root_dir, self.label_dir)
        # File names of all images in that folder
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        # Question 1: return one image together with its label
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir  # the folder name doubles as the label
        return img, label

    def __len__(self):
        # Question 2: total number of samples
        return len(self.img_path)


root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# Adding two datasets concatenates them into a single dataset
train_dataset = ants_dataset + bees_dataset
```
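To try the same pattern without any image files, here is a minimal sketch that uses a hypothetical in-memory dataset, `ToyData` (not part of the original notes), with zero tensors standing in for images. It shows that `+` on two datasets produces one concatenated dataset, and how a `DataLoader` then serves it in batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyData(Dataset):
    """Hypothetical in-memory dataset: every sample shares one label."""

    def __init__(self, num_samples, label):
        # Fake "images": one 3x4x4 tensor per sample.
        self.samples = [torch.zeros(3, 4, 4) for _ in range(num_samples)]
        self.label = label

    def __getitem__(self, idx):
        return self.samples[idx], self.label

    def __len__(self):
        return len(self.samples)


ants = ToyData(5, "ants")
bees = ToyData(3, "bees")
combined = ants + bees  # Dataset.__add__ returns a ConcatDataset

print(len(combined))    # lengths add up: 5 + 3
print(combined[5][1])   # index 5 is the first sample of the second dataset

# DataLoader groups samples into batches; string labels come back as a list.
loader = DataLoader(combined, batch_size=4, shuffle=False)
for imgs, labels in loader:
    print(imgs.shape, labels)
```

Note that the default batching logic of `DataLoader` cannot stack PIL images, so a dataset like `MyData` would first need to convert its images to same-sized tensors (for example with `transforms.ToTensor()` after a resize) before being wrapped in a `DataLoader`.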
IV. Using TensorBoard
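```python
from torch.utils.tensorboard import SummaryWriter

# Event files are written to the "logs" directory
writer = SummaryWriter("logs")

for i in range(100):
    # add_scalar(tag, scalar_value, global_step): plots y = 2x
    writer.add_scalar("y=2x", 2 * i, i)

writer.close()
```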
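How do you open the TensorBoard interface?
In a terminal, run:
tensorboard --logdir="D:\pycharm\learn_pytorch\learn_pytorch\logs"
How do you change the port (to avoid clashing with other users when training on a shared server)?
Pass --port with the port number you want:
tensorboard --logdir="D:\pycharm\learn_pytorch\learn_pytorch\logs" --port=6007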
Logging your own images
```python
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

writer = SummaryWriter("logs")
image_path = "dataset/train/ants/0013035.jpg"
img_PIL = Image.open(image_path)
# Convert the PIL image to a NumPy array of shape (height, width, channels)
img_array = np.array(img_PIL)
# add_image(tag, img_tensor, global_step); the array is HWC, so say so
writer.add_image("test", img_array, 1, dataformats='HWC')

writer.close()
```
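The `dataformats='HWC'` argument matters because `add_image` assumes a channels-first (CHW) layout by default, while an array built from a PIL image is channels-last. A minimal sketch of the two layouts, using a zero-filled array as a stand-in for a real photo (the 512x768 size is an assumption for illustration):

```python
import numpy as np

# Stand-in for np.array(Image.open(...)): an RGB image in HWC layout.
img_hwc = np.zeros((512, 768, 3), dtype=np.uint8)
print(img_hwc.shape)  # (height, width, channels)

# Moving the channel axis to the front gives the CHW layout that
# add_image() assumes when dataformats is not given.
img_chw = np.transpose(img_hwc, (2, 0, 1))
print(img_chw.shape)  # (channels, height, width)
```

Either pass the HWC array with `dataformats='HWC'`, or transpose it to CHW first and rely on the default.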
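If you point image_path at a different image and change the third argument of writer.add_image("test", img_array, 1, dataformats='HWC'), which is the global step (the horizontal axis), from 1 to 2, TensorBoard adds a slider to the existing "test" image panel. Dragging the slider to the right shows the second image.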


```python
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

writer = SummaryWriter("logs")
img = Image.open("dataset/train/ants/0013035.jpg")
print(img)

# ToTensor: PIL image -> float tensor in CHW layout, values scaled to [0, 1]
trans_totensor = transforms.ToTensor()
img_totensor = trans_totensor(img)
writer.add_image("ToTensor", img_totensor)

# Normalize: output = (input - mean) / std, applied per channel
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_totensor)
writer.add_image("Normalize", img_norm)

# Resize with two values: force the output to exactly (height, width)
print(img.size)
trans_resize = transforms.Resize((512, 512))
img_resize = trans_resize(img)           # PIL -> PIL
img_resize = trans_totensor(img_resize)  # PIL -> tensor
writer.add_image("Resize", img_resize, 0)
print(img_resize)

# Resize with one value: the shorter side becomes 512, aspect ratio is kept.
# Compose chains transforms: the output of one is the input of the next.
trans_resize_2 = transforms.Resize(512)
trans_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize2 = trans_compose(img)
writer.add_image("Resize", img_resize2, 1)

# RandomCrop: cut out a random (height, width) region. The image must be
# at least this large unless pad_if_needed=True is passed.
trans_randomCrop = transforms.RandomCrop((500, 1000))
trans_compose2 = transforms.Compose([trans_randomCrop, trans_totensor])
for i in range(10):
    img_crop = trans_compose2(img)
    writer.add_image("RandomCropHW", img_crop, i)

writer.close()
```
Normalize()
Note: pass a tensor image.
It normalizes each channel of the image (3 channels for RGB) using the given mean and standard deviation.
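The per-channel arithmetic is `output = (input - mean) / std`. The sketch below applies that formula to single pixel values in plain Python (the helper `normalize` is hypothetical, written only to show the arithmetic) with the mean 0.5 and standard deviation 0.5 used in the code above.

```python
def normalize(value, mean, std):
    """Per-channel formula applied by Normalize: (value - mean) / std."""
    return (value - mean) / std


# ToTensor produces values in [0, 1]; see where the endpoints end up.
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel, mean=0.5, std=0.5))
```

With these settings the range [0, 1] is mapped onto [-1, 1].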
Resize()
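Note: pass a PIL image (recent torchvision versions also accept a tensor).
If only one value is given, the shorter side of the image is matched to that value and the image is scaled proportionally.
If two values are given, the height and width are set to exactly those values.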
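The two sizing rules can be worked through with a little arithmetic. The helper `resized_size` below is hypothetical and only mirrors the rules described above; it is a sketch, not torchvision's own code. The 768x512 input size is likewise an assumption for illustration.

```python
def resized_size(width, height, size):
    """Hypothetical helper: output (width, height) under Resize-style rules."""
    if isinstance(size, int):
        # One value: match the shorter side, scale the other proportionally.
        if width <= height:
            return size, int(size * height / width)
        return int(size * width / height), size
    # Two values are given as (height, width), as in Resize((h, w)).
    new_height, new_width = size
    return new_width, new_height


print(resized_size(768, 512, 256))         # shorter side (height) becomes 256
print(resized_size(768, 512, (512, 512)))  # exact size, aspect ratio not kept
```

For a 768x512 image, a single value of 256 gives 384x256, while `(512, 512)` gives a square image and distorts the picture.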
Tips
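Pay attention to the input and output types.
Pay attention to which parameters a method needs.
When you do not know what a call returns, use print(), print(type()), or the debugger to find out.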
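A minimal sketch of the last tip, using `os.listdir` (the call `MyData` relies on) as the function under inspection; listing the current directory `"."` is an assumption made so the snippet runs anywhere.

```python
import os

entries = os.listdir(".")  # what does this call actually return?
print(entries)             # inspect the value
print(type(entries))       # inspect the type
```

The same two `print` calls work on the result of any transform, which is a quick way to tell whether it returned a PIL image or a tensor.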