大家好,又见面了,我是你们的朋友全栈君。
class Mish(nn.Module): @staticmethod def forward(x): return x * F.softplus(x).tanh()
class MemoryEfficientMish(nn.Module): class F(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x)))
@staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] sx = torch.sigmoid(x) fx = F.softplus(x).tanh() return grad_output * (fx + x * sx * (1 – fx * fx))
def forward(self, x): return self.F.apply(x)
第一种方式比较占显存,我是用的yolov4+第一种没有跑起来。第二种,是网上扒的,据说还可以,各位可以试试。。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/167136.html原文链接:https://javaforall.cn