Skip to content

这是各个主干网络分类模型的源码,可以用于训练自己的分类模型。

License

Notifications You must be signed in to change notification settings

bubbliiiing/classification-keras

Repository files navigation

Classification:分类模型在Keras当中的实现


目录

  1. 仓库更新 Top News
  2. 所需环境 Environment
  3. 文件下载 Download
  4. 训练步骤 How2train
  5. 预测步骤 How2predict
  6. 评估步骤 How2eval
  7. 参考资料 Reference

Top News

2022-03:进行了大幅度的更新,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整。
BiliBili视频中的原仓库地址为:https://github.com/bubbliiiing/classification-keras/tree/bilibili

2021-01:仓库创建,支持模型训练,大量的注释,多个可调整参数。支持top1-top5的准确度评价。

所需环境

tensorflow-gpu==1.13.1
keras==2.1.5

文件下载

训练所需的预训练权重都可以在百度云下载。
链接: https://pan.baidu.com/s/1lTqDei24Yt0GLJDnjk2uDg
提取码: gpvj

训练所用的示例猫狗数据集也可以在百度云下载。
链接: https://pan.baidu.com/s/1hYBNG0TnGIeWw1-SwkzqpA
提取码: ass8

训练步骤

  1. datasets文件夹下存放的图片分为两部分,train里面是训练图片,test里面是测试图片。
  2. 在训练之前需要首先准备好数据集,在train或者test文件里里面创建不同的文件夹,每个文件夹的名称为对应的类别名称,文件夹下面的图片为这个类的图片。文件格式可参考如下:
|-datasets
    |-train
        |-cat
            |-123.jpg
            |-234.jpg
        |-dog
            |-345.jpg
            |-456.jpg
        |-...
    |-test
        |-cat
            |-567.jpg
            |-678.jpg
        |-dog
            |-789.jpg
            |-890.jpg
        |-...
  1. 之后修改model_data文件夹下的cls_classes.txt,使其也对应自己需要分的类。
  2. 在准备好数据集后,需要在根目录运行txt_annotation.py生成训练所需的cls_train.txt,运行前需要修改其中的classes_path,classes_path需要指向model_data下的txt文件,txt文件中是自己所要去区分的种类,将其修改成自己需要分的类。
  3. 在train.py里面调整自己要选择的网络和权重后,就可以开始训练了!

预测步骤

a、使用预训练权重

  1. 下载完库后解压,model_data已经存在一个训练好的猫狗模型mobilenet025_catvsdog.h5,运行predict.py,输入
img/cat.jpg

b、使用自己训练的权重

  1. 按照训练步骤训练。
  2. 在classification.py文件里面,在如下部分修改model_path、classes_path、backbone和alpha使其对应训练好的文件;model_path对应logs文件夹下面的权值文件,classes_path是model_path对应分的类,backbone对应使用的主干特征提取网络,alpha是当使用mobilenet的alpha值
_defaults = {
    #--------------------------------------------------------------------------#
    #   使用自己训练好的模型进行预测一定要修改model_path和classes_path!
    #   model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
    #   如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
    #--------------------------------------------------------------------------#
    "model_path"    : 'model_data/mobilenet025_catvsdog.h5',
    "classes_path"  : 'model_data/cls_classes.txt',
    #--------------------------------------------------------------------#
    #   输入的图片大小
    #--------------------------------------------------------------------#
    "input_shape"   : [224, 224],
    #--------------------------------------------------------------------#
    #   所用模型种类:
    #   mobilenet、resnet50、vgg16是常用的分类网络
    #--------------------------------------------------------------------#
    "backbone"      : 'mobilenet',
    #--------------------------------------------------------------------#
    #   当使用mobilenet的alpha值
    #   仅在backbone='mobilenet'的时候有效
    #--------------------------------------------------------------------#
    "alpha"         : 0.25
}
  1. 运行predict.py,输入
img/cat.jpg

评估步骤

  1. datasets文件夹下存放的图片分为两部分,train里面是训练图片,test里面是测试图片,在评估的时候,我们使用的是test文件夹里面的图片。
  2. 在评估之前需要首先准备好数据集,在train或者test文件里里面创建不同的文件夹,每个文件夹的名称为对应的类别名称,文件夹下面的图片为这个类的图片。文件格式可参考如下:
|-datasets
    |-train
        |-cat
            |-123.jpg
            |-234.jpg
        |-dog
            |-345.jpg
            |-456.jpg
        |-...
    |-test
        |-cat
            |-567.jpg
            |-678.jpg
        |-dog
            |-789.jpg
            |-890.jpg
        |-...
  1. 在准备好数据集后,需要在根目录运行txt_annotation.py生成训练所需的cls_test.txt,运行前需要修改其中的classes_path,classes_path需要指向model_data下的txt文件,txt文件中是自己所要去区分的种类,将其修改成自己需要分的类。
  2. 之后在classification.py文件里面修改如下部分model_path、classes_path、backbone和alpha使其对应训练好的文件;model_path对应logs文件夹下面的权值文件,classes_path是model_path对应分的类,backbone对应使用的主干特征提取网络,alpha是当使用mobilenet的alpha值
_defaults = {
    #--------------------------------------------------------------------------#
    #   使用自己训练好的模型进行预测一定要修改model_path和classes_path!
    #   model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
    #   如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
    #--------------------------------------------------------------------------#
    "model_path"    : 'model_data/mobilenet025_catvsdog.h5',
    "classes_path"  : 'model_data/cls_classes.txt',
    #--------------------------------------------------------------------#
    #   输入的图片大小
    #--------------------------------------------------------------------#
    "input_shape"   : [224, 224],
    #--------------------------------------------------------------------#
    #   所用模型种类:
    #   mobilenet、resnet50、vgg16是常用的分类网络
    #--------------------------------------------------------------------#
    "backbone"      : 'mobilenet',
    #--------------------------------------------------------------------#
    #   当使用mobilenet的alpha值
    #   仅在backbone='mobilenet'的时候有效
    #--------------------------------------------------------------------#
    "alpha"         : 0.25
}
  1. 运行eval_top1.py和eval_top5.py来进行模型准确率评估。

Reference

https://github.com/keras-team/keras-applications

About

这是各个主干网络分类模型的源码,可以用于训练自己的分类模型。

Resources

License

Stars

Watchers

Forks

Packages

No packages published

Languages