PyTorch 介绍
PyTorch是由Feacbook开源,基于Torch二次开发的Python机器学习库,用于自然语言处理等应用程序。PyTorch既能够看作加入了GPU支撑的numpy,一起也能够看成一个拥有自动求导功用的强大的深度神经网络。
PyTorch 环境建立
1. 装置开发工具和python
pycharm下载
anaconda下载,装置教程自行百度~~
2. 创立项目
-
挑选 Conda 创立新环境
-
挑选 python 版别
-
创立完成后,能够看到当前环境
3. 装置PyTorch环境
-
咱们能够去PyTorch官方:https://pytorch.org/get-started/locally/ 进行挑选,然后履行官方指令装置。
-
也能够经过pip直接装置PyTorch
# cpu用户 pip installtorch==2.0.0 # gpu用户 # 只支撑英伟达的显卡 # 需求提早装置 cuda ,我这儿是11.7,一切下面指定cu117 # 假如已经装置了cuda不知道什么版别,能够经过 nvidia-smi 指令查看 pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu117
在终端履行
-
验证PyTorch是否装置成功
在python控制台履行
import torch
,没有报错阐明装置成功 -
验证PyTorch是否支撑GPU
在python控制台履行
torch.cuda.is_available()
,回来 True 阐明能够运行在显卡上
4. 熟悉PyTorch都有什么
咱们能够经过 dir
和 help
来熟悉
-
经过
dir
看torch都有哪些工具包 -
经过
help
看torch.cuda.is_available的用法和阐明
PyTorch 学习
一、学习 Dataset
Dataset是什么东西,有什么作用
咱们将运用此 数据集 进行演示(也能够用自己的~~)
-
由所以图片数据集,咱们需求提早装置
opencv
pip install opencv-python==4.6.0.66
-
完成咱们Dataset
需求重写三个办法
__init__
类被创立会履行(类似Java构造器),咱们需求在这儿初始化数据__getitem__
根据索引查询,回来 label 和 image__len__
回来 数据集长度from torch.utils.data import Dataset from torch.utils.data.dataset import T_co import os import cv2 as cv # 读取咱们 label 文件的第一行内容 def read_label(path): file = open(path, "r", encoding='utf-8') label = file.readline() file.close() return label class MyDataset(Dataset): def __init__(self, train_path): # 给对象赋值,让其他办法也能够获取到 self.train_path = train_path self.image_path = os.path.join(train_path, 'image') self.label_path = os.path.join(train_path, 'label') # 拿到文件夹下一切图片名 self.image_path_list = os.listdir(self.image_path) def __getitem__(self, index) -> T_co: # 读取图片 image_name = self.image_path_list[index] image_path = os.path.join(self.image_path, image_name) img = cv.imread(image_path) # 读取图片对应的label label_name = 'txt'.join(image_name.rsplit(image_name.split('.')[-1], 1)) label_path = os.path.join(self.label_path, label_name) label = read_label(label_path) return img, label def __len__(self): # 回来数据集长度 return len(self.image_path_list) # 测验 创立 MyDataset 对象 my_dataset = MyDataset("dataset/train") # 拿到下标100的 image 和 label data_index = 100 img, label = my_dataset[data_index] # 展现出来 咱们这儿用到了 __len__ cv.imshow(label + ' (' + str(data_index) + '/' + str(len(my_dataset)) + ')', img) cv.waitKey(0)
作用如下:
二、下期预告
1. TensorBoard 的运用
2. Transforms 的运用
会尽快更新~~~