Printing the Learning Rate in Real Time in TensorFlow 2.x

In TensorFlow 2.x, the following call returns the current (decayed) learning rate:

optimizer._decayed_lr(tf.float32).numpy()
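
Note that _decayed_lr is a private method on the TF2 Keras optimizer base class, so it may change between releases. When the optimizer's learning rate is set to a tf.keras.optimizers.schedules.LearningRateSchedule, a public-API alternative is to evaluate the schedule at the optimizer's current step counter. A minimal sketch (the ExponentialDecay schedule and Adam optimizer below are placeholders, not taken from the original example):

import tensorflow as tf

# Placeholder schedule and optimizer, for illustration only.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=5e-5, decay_steps=1000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)

# Public equivalent of _decayed_lr: evaluate the schedule at the
# optimizer's step counter (incremented on each apply_gradients call).
current_lr = optimizer.learning_rate(optimizer.iterations)
print(float(current_lr))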

The code below shows how to print the current learning rate in real time inside a custom training loop in TensorFlow 2.x. It uses the AdamW optimizer from the BERT repo as an example; the model-construction code is omitted.

import tensorflow as tf
import optimization  # optimization module from the BERT repo (AdamW)
from matplotlib import pyplot as plt
from tqdm import tqdm

init_lr = 5e-5
num_train_steps = 10000
num_warmup_steps = 1000
optimizer = optimization.create_optimizer(init_lr, num_train_steps, num_warmup_steps)

learn_rate = []
for i in tqdm(range(num_train_steps)):
    with tf.GradientTape() as tape:
        # model, data and label come from the omitted model-construction code
        logits = model(data)
        loss = tf.losses.sparse_categorical_crossentropy(label, logits, from_logits=True)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    learn_rate.append(optimizer._decayed_lr(tf.float32).numpy())  # record current learning rate

plt.plot(learn_rate)
plt.show()

Learning-rate curve (adamw_lr): linear warmup over the first num_warmup_steps steps, followed by linear decay.
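
For quick experimentation, here is a self-contained variant that runs without the BERT repo. The toy model, random data, and the built-in PolynomialDecay schedule are stand-ins of my own (assumptions, not part of the original example); the _decayed_lr readout works the same way with any Keras optimizer:

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

# Toy stand-ins for the omitted model, data, and labels (hypothetical).
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
data = np.random.rand(8, 16).astype("float32")
label = np.random.randint(0, 4, size=(8,))

# Built-in linear-decay schedule as a stand-in for BERT's warmup + decay.
schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5, decay_steps=1000, end_learning_rate=0.0)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)

learn_rate = []
for _ in range(1000):
    with tf.GradientTape() as tape:
        logits = model(data)
        loss = tf.losses.sparse_categorical_crossentropy(label, logits, from_logits=True)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    learn_rate.append(optimizer._decayed_lr(tf.float32).numpy())

plt.plot(learn_rate)  # should show a straight line decaying from 5e-5 toward 0
plt.show()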

Reference: https://stackoverflow.com/questions/58149839/learning-rate-of-custom-training-loop-for-tensorflow-2-0/58151051#58151051