- 机器学习实战:模型构建与应用
- (美)劳伦斯·莫罗尼
- 1031字
- 2022-06-28 16:15:59
4.2 在Keras模型中使用TFDS
在第2章中你看到了如何使用TensorFlow和Keras创建一个简单的计算机视觉模型,其中使用了Keras内置的数据集(包括Fashion MNIST),代码如下:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/085-1.jpg?sign=1739041914-nkn2I121iJURZEe4tDE18ulMZ7DqJwkm-0-447b3661a98a8a21c3a0227c51923398)
使用TFDS时代码是非常相似的,但需要一些小的改动。Keras数据集提供的是ndarray
类型,可以直接在model.fit
中使用,但是使用TFDS我们需要做一点转换工作:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/086-1.jpg?sign=1739041914-44PGM8Dtq333QsKM87nEn5f7Ll6CDo3o-0-e0826c29b1db067879d81a7331a5216a)
在这个例子中我们使用了tfds.load
,把fashion_mnist
传给它作为想要的数据集。我们知道它包含train
和test
的分割,因此把这些以数组的形式传送过去会返回一个数据集适配器数组(其中包含图像和标签)。在调用tfds.load
的命令中使用tfds.as_numpy
会导致它们返回Numpy数组。指定batch_size=1
会给我们提供所有的数据,指定as_supervised=True
确保我们得到返回的(输入,标签)的元组。
做完这些,我们就有了Keras数据集中几乎同样格式的数据,只有一个改动—TFDS中的形状是(28,28,1),而Keras数据集中的形状是(28,28)。
这意味着代码需要做一些改动来指定输入数据的形状是(28,28,1)而不是(28,28):
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/086-2.jpg?sign=1739041914-bh7FNc4ovQkJf7ZrK4BfUkcG7TNL7Sgr-0-20986afb79f63d5acd641a7e5b672103)
对于更复杂的例子,你可以查看第3章中使用的Horses or Humans数据集。它同样可以在TFDS中找到。下面是用它来训练一个模型的完整代码:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/086-3.jpg?sign=1739041914-0iGE6JErkLhu5zsfPmfAsHLZJKhtBzc4-0-f0d65d57c668daaa538171cd33c14a1d)
可以看到,它非常直接:只需要调用tfds.load
,传送给它你想要的分割(在这个例子中是train
),并在模型中使用它。数据被分批处理和重组,以使训练更加有效。
Horses or Humans数据集被分为训练集和测试集,因此如果你在训练过程中想对模型进行验证,可以从TFDS加载一个独立的验证集,代码如下:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/087-1.jpg?sign=1739041914-gdnB4SxjlpoHMjSnZIdNHF8axtOrs5FF-0-f4746dacf946a7f75c2ffa03c8128133)
你将需要对它进行分批,就像你对训练集所做的一样。例如:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/087-2.jpg?sign=1739041914-EB35nMlFWDMEW1q8tOdszq7Q8Gvdg8nN-0-3919fd07c3c3167069070560acaba79a)
在训练的时候,你指定训练数据是这些批次。你还需要明确地设置每一个回合使用的验证步数,否则TensorFlow会抛出一个错误。如果你不确定,可以把它设置为1
:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/087-3.jpg?sign=1739041914-pVnV1eaTVD4HtrO09IMdHcPi3QuRU0bj-0-1b6301eb57c0bab6595b5ed9b2e0d25d)
加载具体的版本
所有存储在TFDS中的数据集都使用MAJOR.MINOR.PATCH编号系统。该系统保证了以下规则。如果PATCH被更新,那么调用返回的数据是相同的,但是底层组织可能已经改变。任何改变对于开发者而言应该是不可见的。如果MINOR被更新,那么数据仍然没有变化,除了在每个记录中有额外的特征(非破坏性改变)。同样,对于任何特定的切片(见4.4节)数据也是相同的,因此记录不会被重新排序。如果MAJOR被更新,那么记录的格式和它们的位置可能会有变化,因此特定的片段可能会返回不同的结果。
当检查数据集时,你会发现有不同的版本可以使用。例如,cnn_dailymail
数据集(https://oreil.ly/673CJ)。如果你不想使用默认版本(3.0.0),而想使用更早的版本(例如1.0.0),可以像这样加载它:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/088-1.jpg?sign=1739041914-F7puHOd2eKoIPA7tBbzXopigWvirBtc5-0-0047f1948eca4d482f7dd5ee059e29d9)
注意,如果你正在使用Colab,那么检查TFDS使用的版本总是一个好主意。在写作本书时,Cload被预先设置为TFDS 2.0,但是TFDS 2.1和之后的版本解决了一些加载数据集的错误(包括cnn_dailymail
),因此确保使用这些版本的其中一个,或者最起码将它们安装到Colab中,而不是依赖默认的版本。