Skip to content

本仓库基于resnet50网络,运用pytorch对cifar10进行分类,可以调用另外两种网络densenet,goooglenet与resnet50网络进行对比

Notifications You must be signed in to change notification settings

crabbit37927/pytorch-resnet50-with-cifar10

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

14 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

pytorch-resnet50 and cifar-10

写在前面

仓库持有者是SZU22级交通专业学生,此仓库用于记录我的学习生活——to crabbit

概述

通过pytorch里面的resnet50模型实现对cifar-10数据集的分类,并将混淆矩阵和部分特征图可视化。

最终测试集的准确率达到95%以上。

辅以另外两个网络googlenet和densenet与resnet50进行对比

环境配置

python版本3.9

cuda版本11.7

pytorch版本2.0.0+cu117

torchvision版本0.15.2+cu117

torchaudio版本0.13.1+cu117

如果你还没安装环境,请一定要确定好cuda,python,以及torch这三者之间的版本依赖,这很重要!

代码组成

  • main.py

  • 用于运行主程序

  • get_data.py

  • 用于下载cifar-10数据,并对数据进行预处理

    可以在此文件中更改self.train_batch_size与self.test_batch_size为你需要的batch size

  • modle.py

  • 用于训练与测试模型

    可以在此文件中更改self.epoch为你需要的epoch数

    可以更改self.optimizer中的lr为你需要的学习率

  • visual.py

  • 进行所有可视化操作,包括训练过程中acc-loss的变化曲线,测试集的混淆矩阵

    请注意:该项目中的混淆矩阵对y轴进行了翻转操作

    当模型选择为resnet时,会在模型训练后挑选十张图片,输出原图与预测结果,如果预测正确,字体颜色为绿,反之为红。

    此外,还会挑选一张照片,显示模型训练完毕后,在每一层的部分特征图,类似下图:

写在最后

本人非计软学生,python基本都是自学,水平有限,欢迎各位指正我代码的问题,对此项目有疑问也欢迎来问我。联系方式[email protected]

About

本仓库基于resnet50网络,运用pytorch对cifar10进行分类,可以调用另外两种网络densenet,goooglenet与resnet50网络进行对比

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages