
How to use CosineDecayRestarts in TF 2.2

Stack Overflow user
Asked 2020-08-31 14:14:15
1 answer, 833 views, 0 followers, 0 votes

Here is my code:

def make_model(nh, lr_scheduler):
    z = L.Input((nh,), name="Patient")
    x = L.Dense(100, activation="relu", name="d1")(z)
    x = L.Dense(100, activation="relu", name="d2")(x)
    #x = L.Dense(100, activation="relu", name="d3")(x)
    p1 = L.Dense(3, activation="linear", name="p1")(x)
    p2 = L.Dense(3, activation="relu", name="p2")(x)
    preds = L.Lambda(lambda x: x[0] + tf.cumsum(x[1], axis=1),
                     name="preds")([p1, p2])

    model = M.Model(z, preds, name="CNN")
    #model.compile(loss=qloss, optimizer="adam", metrics=[score])
    model.compile(
        loss=mloss(0.8),
        optimizer=tf.keras.optimizers.Adam(
            lr=tf.keras.experimental.CosineDecayRestarts(
                0.1, iters/4, t_mul=2.0, m_mul=1.0, alpha=0.0, name=None),
            beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.01, amsgrad=False),
        metrics=[score])
    return model

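I just want to try the CosineDecayRestarts learning-rate scheduler. After checking the TF 2.2 API, I believe I am using it correctly, but it raises an error:

[Screenshot of the error message; image not available]

Can anyone help?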
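
For reference, the arguments in the snippet above determine where the learning rate restarts. The sketch below is framework-free and rests on one loud assumption: that `iters` (which the snippet never defines) is the total number of training steps. `restart_steps` is a hypothetical helper written for illustration, not part of TensorFlow.

```python
def restart_steps(first_decay_steps, t_mul, total_steps):
    """Return the steps at which a cosine-decay-with-restarts schedule restarts.

    Each decay period is t_mul times longer than the previous one, so the
    restarts fall at the running sum of the period lengths.
    """
    boundaries = []
    period = float(first_decay_steps)
    boundary = period
    while boundary < total_steps:
        boundaries.append(boundary)
        period *= t_mul
        boundary += period
    return boundaries


# Hypothetical value: the question does not say what `iters` is.
iters = 800
print(restart_steps(iters / 4, t_mul=2.0, total_steps=iters))  # → [200.0, 600.0]
```

Under that assumption, `first_decay_steps=iters/4` with `t_mul=2.0` gives two restarts within `iters` steps, and the third decay period runs past the end of training.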

1 answer

Stack Overflow user

Answered 2021-01-04 19:19:53

I had the same problem, so I wrote my own version based on https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py#L611-L734 and https://www.tensorflow.org/guide/keras/custom_callback.

import math

import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


class CosineDecayRestarts(tf.keras.callbacks.Callback):
    def __init__(self, initial_learning_rate, first_decay_steps, alpha=0.0, t_mul=2.0, m_mul=1.0):
        super(CosineDecayRestarts, self).__init__()
        self.initial_learning_rate = initial_learning_rate
        self.first_decay_steps = first_decay_steps
        self.alpha = alpha
        self.t_mul = t_mul
        self.m_mul = m_mul
        self.batch_step = 0

    def on_train_batch_begin(self, step, logs=None):
        if not hasattr(self.model.optimizer, "lr"):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Get the current learning rate from model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(self.batch_step, lr)
        # Set the value back to the optimizer before this batch starts
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
        self.batch_step += 1

    def schedule(self, step, lr):
        def compute_step(completed_fraction, geometric=False):
            """Helper for `cond` operation."""
            if geometric:
                i_restart = math_ops.floor(
                  math_ops.log(1.0 - completed_fraction * (1.0 - self.t_mul)) /
                  math_ops.log(self.t_mul))

                sum_r = (1.0 - self.t_mul**i_restart) / (1.0 - self.t_mul)
                completed_fraction = (completed_fraction - sum_r) / self.t_mul**i_restart

            else:
                i_restart = math_ops.floor(completed_fraction)
                completed_fraction -= i_restart

            return i_restart, completed_fraction

        completed_fraction = step / self.first_decay_steps

        i_restart, completed_fraction = control_flow_ops.cond(
          math_ops.equal(self.t_mul, 1.0),
          lambda: compute_step(completed_fraction, geometric=False),
          lambda: compute_step(completed_fraction, geometric=True))

        m_fac = self.m_mul**i_restart
        cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
          constant_op.constant(math.pi) * completed_fraction))
        decayed = (1 - self.alpha) * cosine_decayed + self.alpha

        return math_ops.multiply(self.initial_learning_rate, decayed)

It is used as a callback (like other schedulers). It seems to work well. I hope you find it useful.
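
To sanity-check the values such a callback should produce, the same arithmetic can be reproduced without TensorFlow. The following is only a sketch that mirrors the `schedule` method above using the standard `math` module; `cosine_decay_restarts` is a hypothetical standalone function, and the step counts in the usage lines are made-up values.

```python
import math


def cosine_decay_restarts(step, initial_learning_rate, first_decay_steps,
                          t_mul=2.0, m_mul=1.0, alpha=0.0):
    """Learning rate at `step` for cosine decay with warm restarts."""
    completed_fraction = step / first_decay_steps
    if t_mul == 1.0:
        # All periods have equal length.
        i_restart = math.floor(completed_fraction)
        completed_fraction -= i_restart
    else:
        # Period lengths grow geometrically by a factor of t_mul.
        i_restart = math.floor(
            math.log(1.0 - completed_fraction * (1.0 - t_mul)) / math.log(t_mul))
        sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
        completed_fraction = (completed_fraction - sum_r) / t_mul ** i_restart

    # m_mul scales the peak learning rate after each restart.
    m_fac = m_mul ** i_restart
    cosine_decayed = 0.5 * m_fac * (1.0 + math.cos(math.pi * completed_fraction))
    # alpha sets the floor as a fraction of the initial learning rate.
    decayed = (1.0 - alpha) * cosine_decayed + alpha
    return initial_learning_rate * decayed


# Hypothetical settings: initial rate 0.1, first period of 100 steps.
for step in (0, 50, 100, 200):
    print(step, round(cosine_decay_restarts(step, 0.1, 100), 4))
```

With these settings the rate starts at 0.1, falls to about half of that midway through the first period, and returns to 0.1 at step 100, where the second (twice as long) period begins.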

0 votes

Page content originally provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/63665686
