TensorFlow 2.0重大更新
TensorFlow 2.0是TF史上最重要的升级,带来了诸多革命性变化。
核心改进
1. Keras集成
2.0将Keras作为官方高级API:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| import tensorflow as tf
model = tf.keras.Sequential([ tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation='softmax') ])
model.compile( optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'] )
model.fit(train_data, epochs=10)
|
2. Eager Execution
默认启用动态图,调试更直观:
1 2 3
| result = tf.constant([[1, 2], [3, 4]]) + tf.constant([[5, 6], [7, 8]]) print(result.numpy())
|
3. tf.function装饰器
将Python代码编译为高性能图:
1 2 3 4 5 6 7
| @tf.function def train_step(images, labels): with tf.GradientTape() as tape: predictions = model(images) loss = loss_object(labels, predictions) gradients = tape.gradient(loss, model.trainable_variables) optimizer.apply_gradients(zip(gradients, model.trainable_variables))
|
4. 统一数据管道
tf.data提供高效数据处理:
1 2 3 4
| dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) dataset = dataset.shuffle(10000) dataset = dataset.batch(32) dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
迁移指南
| TF 1.x |
TF 2.0 |
| tf.Session |
Eager Execution |
| tf.placeholder |
Keras Input |
| tf.global_variables_initializer() |
无需 |
| tf.contrib |
tf-addons |
总结
TensorFlow 2.0大大降低了学习门槛,同时保持了生产级别的性能。