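The following keras model produces the error below when trained with a user-supplied loss function. Here is an MWE (using just a dummy loss function to demonstrate the problem):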
library(keras)
## Set up the model
in.lay <- layer_input(shape = 2)
out.lay <- layer_dense(in.lay, units = 2, activation = "sigmoid")
model <- keras_model(in.lay, out.lay)
## Compile
NN <- compile(model, optimizer = "adam", loss = function(x, y = out.lay) 1) # dummy loss function; fit() below fails
## Train
n <- 10000
set.seed(271)
data <- matrix(runif(n * 2), ncol = 2) # dummy training data
prior <- qnorm(matrix(runif(n * 2), ncol = 2)) # dummy prior data
fit(NN, x = prior, y = data, batch_size = 1000, epochs = 10) # training (fails with the error below)
Here is the error message I get:
Epoch 1/10
Error in py_call_impl(callable, dots$args, dots$keywords) :
  ValueError: in user code:
    /usr/local/tensorflow/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    /usr/local/tensorflow/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:951 run  **
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/tensorflow/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/tensorflow/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/tensorflow/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:532 train_step  **
        loss = self.compiled_loss(
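    /usr/local/tensorflow/lib/python3.8/site-packages/tensorflow/python/keras/engine/compile_util
Note that if a different loss function is supplied (e.g., as in NN <- compile(model, optimizer = "adam", loss = loss_mean_squared_error)), the model above trains fine with the fit() call. Also note that I was previously able to use user-supplied loss functions without problems, so this could be due to an update of an R package, TensorFlow, Keras, etc. (I don't know). Calling debug(keras:::fit.keras.engine.training.Model) before calling fit() shows that history <- do.call(object$fit, args) in keras:::fit.keras.engine.training.Model() fails, but neither object$fit nor args looks suspicious. Here is my sessionInfo() in R: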
sessionInfo()
R version 4.0.4 (2021-02-15)
Platform: x86_64-apple-darwin20.3.0 (64-bit)
Running under: macOS Big Sur 11.5.2
Matrix products: default
BLAS:   /usr/local/R/R-4.0.4_build/lib/libRblas.dylib
LAPACK: /usr/local/R/R-4.0.4_build/lib/libRlapack.dylib
locale:
[1] en_CA.UTF-8/en_CA.UTF-8/en_CA.UTF-8/C/en_CA.UTF-8/en_CA.UTF-8
attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base
other attached packages:
[1] keras_2.3.0.0.9000
loaded via a namespace (and not attached):
 [1] Rcpp_1.0.7       lattice_0.20-41  here_1.0.1       png_0.1-7
 [5] rprojroot_2.0.2  zeallot_0.1.0    rappdirs_0.3.3   grid_4.0.4
 [9] R6_2.5.0         jsonlite_1.7.2   magrittr_2.0.1   tfruns_1.4
[13] whisker_0.4      Matrix_1.3-2     reticulate_1.22  generics_0.1.0
 [17] compiler_4.0.4   base64enc_0.1-3  tensorflow_2.6.0
Posted on 2021-09-19 15:39:29
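I realized that, after the update, the way I start R had changed, and that the above error occurs when R is run outside the required Python environment. After I started R from within the Python environment in which TensorFlow runs, the problem disappeared completely.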
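One way to check this from within R is sketched below, assuming the reticulate and tensorflow R packages are installed. The environment path is a hypothetical placeholder and is not taken from the setup described above; py_config() and tf_config() only report which Python and TensorFlow installation the R session picked up, so the output will differ from machine to machine.

```r
library(reticulate)

## Hypothetical path (an assumption): replace with the environment
## in which TensorFlow is actually installed
tf_env <- "~/.virtualenvs/r-tensorflow"

## Select the environment; this has to happen before Python is
## initialized in the current R session
use_virtualenv(tf_env, required = TRUE) # for conda: use_condaenv("<name>", required = TRUE)

## Report which Python binary the session is bound to
py_config()

## Report which TensorFlow installation (if any) was found
tensorflow::tf_config()
```

If the reported Python binary is not the one from the TensorFlow environment, restarting R and selecting the environment first (or starting R from within that environment, as described above) is a reasonable next step.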
https://stackoverflow.com/questions/69238317