EBET易博深度学习完整的模型训练(以VAE+Classifier为例)

  新闻资讯     |      2023-10-01 22:18

  模型的主要组成(如图1所示):模型由4层VAE(L4VAE)构成的降维网络+3层全连接层组成的分类器(Classifier)组成,采用torchvision.datasets中自带的数据集MNIST(手写数字识别)。

  模型的主要任务:首先通过VAE的编码器(encoder)将输入数据的维度input_dim(MNIST数据的维度为(28,28),展开为一维后变为(784,))降低到合适的维度latent_dim,然后将维度为latent_dim的数据送入分类器Classifier进行训练(10分类),完成手写数字的识别任务。其中VAE的解码器(decoder)的作用是配合编码器进行训练,以便使得降维后的数据可以较好的表征输入数据。

  评价指标:对于模型的训练和测试结果或效果,可以通过损失函数Loss、准确率Accuracy、精度Precision、F1_score、召回率Recall和AUC曲线进行评估。为简单,本模型的训练和测试结果只以可视化的方式显示模型的损失函数Loss和准确率Accuracy。(注:损失函数的值是一个相对值,是表征模型训练前相比较于训练后的差别程度,因此,损失函数的值并不是越小越好,它只能反应模型是否在进行“学习”,对于模型的性能优良,最终还是需要通过准确率Accuracy等其他评价指标进行评估)

  其中第一行便是创建数据集的例子,第二行则是加载数据集,在第二步会详细介绍。

  我们使用的EBET易博真人是MNIST数据集,在该页面滚动鼠标,找到MNIST数据集,点击左边的名称可以查看MNIST的用法,点击右边的名称则可以了解MNIST数据集的详细信息。这里点击左边的MNIST,出现下面页面(如图3所示)。

  其中各个参数的含义分别为,root:数据集的根目录,输入为一个路径;train:可选参数,bool型,默认为True,表示加载的数据集的应用;download:可选参数,bool型,默认为False,表示是否下载数据集(建议为True);transform:可选参数,表示对数据作何种变换;target_transform:可选参数,表示对数据标签作何种变换。

  可以采用Pytorch自带的损失函数,也可以采用个性化(自己编写)的损失函数。