我正在尝试使用自定义损失函数,先从最简单的均方误差(MSE)开始。不必在意 oscillator 函数,它只是用来生成数据的。
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, Dense
from keras.models import Model
import tensorflow as tf
def oscillator(d_, w0_, x):
    """Evaluate the under-damped harmonic-oscillator solution at ``x``.

    Computes ``exp(-d_*x) * cos(phi + w*x) / cos(phi)`` with
    ``w = sqrt(w0_**2 - d_**2)`` and ``phi = arctan(-d_/w)``.  This equals
    ``exp(-d_*x) * (cos(w*x) + (d_/w) * sin(w*x))``, so the value is 1 at
    ``x = 0`` and the slope there is 0.

    Args:
        d_: Decay rate of the ``exp(-d_*x)`` envelope; must be < ``w0_``.
        w0_: Frequency parameter; the actual oscillation frequency is
            ``sqrt(w0_**2 - d_**2)``.
        x: Scalar or NumPy array of evaluation points (evaluated elementwise).

    Returns:
        NumPy scalar or array with the same shape as ``x``.

    Raises:
        ValueError: If ``d_ >= w0_``.  In that case ``w`` would be zero or NaN
            and the formula is meaningless.
    """
    # An explicit check replaces `assert`, which is stripped under `python -O`.
    # The `not (<)` form also rejects NaN inputs, just as the assert did.
    if not d_ < w0_:
        raise ValueError(
            f"oscillator requires d_ < w0_ (under-damped), got d_={d_}, w0_={w0_}"
        )
    w = np.sqrt(w0_**2 - d_**2)
    phi = np.arctan(-d_ / w)
    # Amplitude chosen so that the solution starts at exactly 1 for x = 0.
    A = 1 / (2 * np.cos(phi))
    cos = np.cos(phi + w * x)
    exp = np.exp(-d_ * x)
    return exp * 2 * A * cos
# PARAMETERS:
np.random.seed(5)  # makes the NumPy draw of the training inputs repeatable
N = 20             # number of training samples
epochs = 2000      # passed to model.fit
d = 2              # decay rate d_ handed to oscillator
w0 = 20            # frequency parameter w0_ handed to oscillator
nn_dim = 64        # units in each hidden Dense layer
# DATA:
# Reference curve evaluated on an evenly spaced grid over [0, 1].
x = np.linspace(0.0, 1.0, num=100)
y = oscillator(d, w0, x)
# N sorted sample points from [0, 0.35), shaped as an (N, 1) column for Keras.
x_train = np.sort(np.random.uniform(low=0, high=0.35, size=N)).reshape(-1, 1)
y_train = oscillator(d, w0, x_train)
tf_y = tf.Variable(initial_value=y_train, dtype=tf.float32)
# LAYERS:
input_layer = Input(shape=(1,))
Layer_1 = Dense(nn_dim, activation="tanh")(input_layer)
Layer_2 = Dense(nn_dim, activation="tanh")(Layer_1)
output_layer = Dense(1)(Layer_2)
model = Model(inputs=input_layer, outputs=output_layer)


def loss_func(y_true, y_pred):
    """Return the mean squared error between targets and predictions.

    ``model.compile(loss=...)`` expects a loss name, a
    ``tf.keras.losses.Loss`` instance, or a callable ``fn(y_true, y_pred)``.
    The previous version built a tensor from the symbolic ``output_layer``
    and passed that tensor as the loss.  That is what raised the "Keras
    symbolic inputs/outputs" TypeError; the shapes were fine, since both
    ``y_train`` and the ``Dense(1)`` output are column vectors.

    Keras supplies the targets from ``model.fit(x_train, y_train)``, so no
    separate ``tf.Variable`` of targets is needed here.
    """
    # squared_difference requires both operands to share a dtype, so align
    # the targets with the predictions' dtype before subtracting.
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(tf.math.squared_difference(y_true, y_pred))


# Pass the function itself, not the result of calling it.
model.compile(optimizer='adam', loss=loss_func, metrics=['mse'])
md = model.fit(x_train, y_train, epochs=epochs, verbose=1)
y_pred = model.predict(x[:, np.newaxis])
# PLOTTING:
# Figure 1: the per-epoch training loss recorded by model.fit.
fig = plt.figure()
plt.plot(md.history["loss"], label="training")
plt.legend()

# Figure 2: reference curve, the training samples, and the network's prediction.
plt.figure()
plt.plot(x, y, label="Exact solution")
plt.scatter(x_train, y_train, color="orange", label="Data")
plt.plot(x, y_pred, color="red", linestyle="--", label="Prediction")
plt.legend()
plt.show()

上面的代码产生以下错误:

TypeError: Keras 符号输入/输出未实现 `__len__`。您可能试图将 Keras 符号输入/输出传递给未注册调度(dispatch)的 TF API,这会阻止 Keras 自动将该 API 调用转换为函数式模型中的 Lambda 层。如果尝试直接对符号输入/输出进行断言,也会引发此错误。

进程已结束,退出代码为 1。
问题出在 `loss_func = tf.reduce_mean(tf.math.squared_difference(tf_y, output_layer))` 这一行。我认为这是因为 `tf_y` 和 `output_layer` 的维度不同。该如何用 `output_layer` 和 `y` 手动计算均方误差(MSE)?
发布于 2022-07-27 22:50:23
我个人从未见过这样定义损失的(而且我认为这样很难行得通):`compile` 的 `loss` 参数需要的是一个可调用对象,而不是基于符号张量 `output_layer` 算出的张量。通常应该定义一个函数:
def loss_func(tf_y, output_layer):
    """Return the mean squared error in the ``fn(y_true, y_pred)`` form.

    Pass the function itself to ``model.compile(loss=loss_func)``.  Keras
    then calls it with the ground-truth values as the first argument
    (``tf_y``) and the model's predictions as the second (``output_layer``).
    """
    # squared_difference requires both operands to share a dtype, so align
    # the targets with the predictions' dtype before subtracting.
    tf_y = tf.cast(tf_y, output_layer.dtype)
    return tf.reduce_mean(tf.math.squared_difference(tf_y, output_layer))
# From the documentation:
`Model.compile` 的 `loss` 参数:损失函数。可以是字符串(损失函数的名称),也可以是 `tf.keras.losses.Loss` 实例,参见 `tf.keras.losses`。损失函数是任何签名为 `loss = fn(y_true, y_pred)` 的可调用对象,其中 `y_true` 是真实值(ground truth),`y_pred` 是模型的预测值。
https://stackoverflow.com/questions/73142975
复制相似问题