PrettyTensor
PrettyTensor 在 TensorFlow 上提供了一个薄包装器。 PrettyTensor 提供的对象支持可链接的语法来定义神经网络。例如,可以通过链接层来创建模型,如以下代码所示:
model = (X.
flatten().
fully_connected(10).
softmax_classifier(n_classes, labels=Y))
可以使用以下命令在 Python 3 中安装 PrettyTensor:
pip3 install prettytensor
PrettyTensor 以名为apply()
的方法提供了一个非常轻量级和可扩展的界面。可以使用.apply(function, arguments)
方法将任何附加函数链接到 PrettyTensor 对象。 PrettyTensor 将调用function
并提供当前张量作为function
的第一个参数。
User-created functions can be added using the @prettytensor.register
decorator. Details can be found at https://github.com/google/prettytensor.
在 PrettyTensor 中定义和训练模型的工作流程如下:
- 获取数据。
- 定义超参数和参数。
- 定义输入和输出。
- 定义模型。
- 定义评估程序,优化程序和训练器函数。
- 创建跑步者对象。
- 在 TensorFlow 会话中,使用
runner.train_model()
方法训练模型。 - 在同一会话中,使用
runner.evaluate_model()
方法评估模型。
笔记本ch-02_TF_High_Level_Libraries
中提供了 PrettyTensor MNIST 分类示例的完整代码。 PrettyTensor MNIST 示例的输出如下:
[1] [2.5561881]
[600] [0.3553167]
Accuracy after 1 epochs 0.8799999952316284
[601] [0.47775066]
[1200] [0.34739292]
Accuracy after 2 epochs 0.8999999761581421
[1201] [0.19110668]
[1800] [0.17418651]
Accuracy after 3 epochs 0.8999999761581421
[1801] [0.27229539]
[2400] [0.34908807]
Accuracy after 4 epochs 0.8700000047683716
[2401] [0.40000191]
[3000] [0.30816519]
Accuracy after 5 epochs 0.8999999761581421
[3001] [0.29905257]
[3600] [0.41590339]
Accuracy after 6 epochs 0.8899999856948853
[3601] [0.32594997]
[4200] [0.36930788]
Accuracy after 7 epochs 0.8899999856948853
[4201] [0.26780865]
[4800] [0.2911002]
Accuracy after 8 epochs 0.8899999856948853
[4801] [0.36304188]
[5400] [0.39880857]
Accuracy after 9 epochs 0.8999999761581421
[5401] [0.1339224]
[6000] [0.14993289]
Accuracy after 10 epochs 0.8899999856948853