南京网站设计搭建公司,金华大企业网站建设有哪些,网站开发的对联,wordpress 推送到微信教程#xff1a;现有网络模型的使用及修改_哔哩哔哩_bilibili
官方网址#xff1a;https://pytorch.org/vision/stable/models.html#classification 初识网络模型
pytorch为我们提供了许多已经构造好的网络模型#xff0c;我们只要将它们加载进来#xff0c;就可以直接使… 教程现有网络模型的使用及修改_哔哩哔哩_bilibili
官方网址https://pytorch.org/vision/stable/models.html#classification 初识网络模型
pytorch为我们提供了许多已经构造好的网络模型我们只要将它们加载进来就可以直接使用。以torchvision为例关于神经网络处理图像的模型就分为好几个大类如图像分类、目标检测、语义分割等等。如图所示 视频中的讲解以VGG模型为例来向我们展示了网络模型的使用。
因为这个教学视频也已经是两三年前了的现在和之前略微有所区别。在这里简单做一个说明比如说模型加载过程中参数的改变 如今的模型中不再有pretrained参数也就是如果大家需要下载模型的权重文件需要自己手动下载。务必注意写了会报错哦。
权重文件的下载 视频中有讲到模型的下载也是不大不小的如果不进行设置一般会默认下载在c盘想要进行设置的话可以在网上搜索有关代码Pytorch预训练模型下载并加载以VGG为例自定义路径_怎么更改vgg下载路径-CSDN博客
但以上这位同学的方法我使用时出错提示我没有这个属性
model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True)
AttributeError: module torch.utils.model_zoo has no attribute _download_url_to_file
所以我略加修改以下是我的处理下载过程同样出错的同学可以看看
from urllib.parse import urlparse
import torch
# import re
import os
def download_model(url, dst_path):parts urlparse(url)filename os.path.basename(parts.path)# HASH_REGEX re.compile(r-([a-f0-9]*)\.)# hash_prefix HASH_REGEX.search(filename).group(1)torch.hub.download_url_to_file(url, os.path.join(dst_path, filename))return filenamepath D:\\vscodeProjects\\models
if not (os.path.exists(path)):os.makedirs(path)
urlhttps://download.pytorch.org/models/vgg16-397923af.pth
download_model(url, path) 只是这个下载的速度着实太慢我先放弃了 关于这个权重文件的下载我犯了一点小迷糊。我有点搞不懂为什么费劲巴拉下载这么大个东西然后视频中又仅仅使用vgg16torchvision.models.vgg16()这一句话就完事了。
于是我搜索了一下 在 PyTorch 中许多流行的深度学习模型如 VGG、ResNet、AlexNet 等都有预先训练好的权重文件可供下载。这些权重文件包含了模型在大规模数据集如 ImageNet上训练的参数可以帮助加快模型的收敛速度提升模型的表现。下载预训练模型通常是为了避免从头开始训练模型节省时间和计算资源。torchvision.models 是 PyTorch 提供的一个模块用于加载常见的计算机视觉模型例如 VGG、ResNet、AlexNet 等。这些模型可以通过简单的调用来导入并且可以选择加载预训练的权重。 简而言之权重文件可以简化我们模型的训练过程我们可以通过使用权重文件来直接利用前辈的训练结果稍作修改就可以变成我们自己的东西。
如果只是用vgg16torchvision.models.vgg16()这么一句话来加载网络模型得到的模型只有结构而没有经过训练的过程因此它的权重是初始的。
网络模型的修改
因为官网中提到的VGG模型的官配数据集ImageNet实在是太大了100个G笔记本实在带不了所以还是使用我们之前已经用了很多次的数据集CIFAR10来搞正好可以讲解一下怎样修改网络模型。
原官配数据集非常之大对我一个初学者来说是暂时见过最大的数据集了最终一共分为1000个类。因此这个VGG模型最终输出为1000为了适配于我们这个CIFAR10数据集输出只有10类我们为加载下来的VGG模型添加一个线性层将原本的1000个类最终输出为10类。
from torch import nn
import torchvision
vgg16torchvision.models.vgg16()
train_datatorchvision.datasets.CIFAR10(../dataset,trainTrue,transformtorchvision.transforms.ToTensor())
vgg16.add_module(add_linear,nn.Linear(1000,10))print(vgg16)可以看到最下面就是我们新添加的层 如果我们想添加在classifier这个模型中我们也可以这样写
vgg16.classifier.add_module(add_linear,nn.Linear(1000,10))
同样打印一下看效果 当然如果我们不想添加新的一层我们也可以通过另外的一种方式来将输出从1000改为10 如上图所示已知最后一层是线性层输入4096输出1000那么我们现在直接将最后一个线性层修改输出改成10
vgg16.classifier[6]nn.Linear(in_features4096,out_features10,biasTrue)
看结果 模型的保存和加载
如果我们对网络模型进行了修改或者训练如何将我们自己的模型保存下来呢一共有以下两种方式
vgg16torchvision.models.vgg16()
vgg16.classifier[6]nn.Linear(in_features4096,out_features10,biasTrue)
#保存方式一保存权重文件和模型结构
torch.save(vgg16,vgg16_method1.pth)
#保存方式二官方推荐实际上保存的是权重文件以字典方式存储
torch.save(vgg16.state_dict(),vgg16_method2.pth)
而如果我们想要取出我们已经保存的模型就可以
#方式一加载保存的模型
vgg16_method1torch.load(vgg16_method1.pth)
#方式二加载保存的权重文件
vgg16_method2torch.load(vgg16_method2.pth)
vgg16torchvision.models.vgg16()
vgg16.load_state_dict(vgg16_method2)