In TensorFlow 2.x, the current learning rate can be displayed with the following code:
```python
optimizer._decayed_lr(tf.float32).numpy()
```
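`_decayed_lr` is a private method. As far as I know it belongs to the Keras `OptimizerV2` base class, which TF 2.11 and later keep only under `tf.keras.optimizers.legacy`. If the optimizer was built from a `LearningRateSchedule`, a public-API alternative is to call the schedule on the optimizer's step counter `optimizer.iterations`. For a schedule-based optimizer without the legacy `decay` argument, this should give the same value. The sketch below assumes a `PolynomialDecay` schedule, a plain `Adam` optimizer, a tiny `Dense` layer, and a helper named `current_lr`. All of these are hypothetical stand-ins chosen for illustration.

```python
import tensorflow as tf

# Hypothetical setup: linear decay from 5e-5 to 0 over 10000 steps
schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5, decay_steps=10000, end_learning_rate=0.0)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)

# A tiny layer so that there is something to optimize
layer = tf.keras.layers.Dense(1)
x = tf.ones((4, 2))
layer(x)  # build the weights


def current_lr(opt, sched):
    """Learning rate the schedule yields at the optimizer's step counter."""
    step = tf.convert_to_tensor(opt.iterations)
    return float(sched(step))


lrs = []
for _ in range(3):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(layer(x)))
    grads = tape.gradient(loss, layer.trainable_weights)
    optimizer.apply_gradients(zip(grads, layer.trainable_weights))
    # apply_gradients has already advanced opt.iterations by one
    lrs.append(current_lr(optimizer, schedule))

print(lrs)
```

Keeping a reference to the schedule object avoids depending on how a given optimizer version exposes its `learning_rate` attribute.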
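The following code shows how to print the current learning rate in real time while training in TensorFlow 2.x.
It uses the AdamW optimizer from BERT as an example. The model-building part is omitted.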
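```python
import tensorflow as tf
import optimization  # optimization.py from the TF2 BERT code
from matplotlib import pyplot as plt
from tqdm import tqdm

init_lr = 5e-5
num_train_steps = 10000
num_warmup_steps = 1000

# AdamW with learning-rate warmup followed by decay
optimizer = optimization.create_optimizer(init_lr, num_train_steps, num_warmup_steps)

learn_rate = []
# model, data and label come from the omitted model-building code
pbar = tqdm(range(num_train_steps))
for i in pbar:
    with tf.GradientTape() as tape:
        logits = model(data)
        loss = tf.reduce_mean(
            tf.losses.sparse_categorical_crossentropy(label, logits, from_logits=True))
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # Learning rate at the optimizer's current iteration count (private API)
    current_lr = optimizer._decayed_lr(tf.float32).numpy()
    learn_rate.append(current_lr)
    pbar.set_postfix(lr=float(current_lr))  # show it live in the progress bar

plt.plot(learn_rate)
plt.xlabel("step")
plt.ylabel("learning rate")
plt.show()
```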
Figure: learning-rate curve
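To sanity-check the plotted curve, the schedule can be written out in plain Python. This is a minimal sketch that rests on one assumption: `create_optimizer` pairs linear warmup over `num_warmup_steps` with a `PolynomialDecay` (power 1, end rate 0) over `num_train_steps`, and the decay is evaluated on the global step. The helper `bert_lr` is hypothetical and is not part of `optimization.py`.

```python
def bert_lr(step, init_lr=5e-5, num_train_steps=10000,
            num_warmup_steps=1000, end_lr=0.0):
    """Hypothetical warmup + linear-decay schedule (assumed, not copied from optimization.py)."""
    if step < num_warmup_steps:
        # Linear warmup: ramp from 0 to init_lr over num_warmup_steps
        return init_lr * step / num_warmup_steps
    # Linear decay evaluated on the global step, clipped at num_train_steps
    step = min(step, num_train_steps)
    return (init_lr - end_lr) * (1 - step / num_train_steps) + end_lr


curve = [bert_lr(s) for s in range(10000)]
for s in (0, 500, 999, 1000, 5000, 9999):
    print(s, bert_lr(s))
```

Under this assumption the curve rises linearly to just below 5e-5 at step 999. It then drops to 4.5e-5 at step 1000, because the decay term uses the global step rather than the number of steps since warmup ended. After that it falls linearly toward 0.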
Reference:
https://stackoverflow.com/questions/58149839/learning-rate-of-custom-training-loop-for-tensorflow-2-0/58151051#58151051