首页
学习
活动
专区
工具
TVP
发布
社区首页 > 问答首页 > 训练好模型之后，进行预测时出现 NotImplementedError，怎么办？

训练好模型之后，进行预测时出现 NotImplementedError，怎么办？

提问于 2022-10-23 19:38:39
回答 0 · 关注 0 · 查看 100

预测部分的代码：

# --- Evaluate the saved model on the test split ---
IDS = list(yields['test'].keys())

# NOTE(review): generator() draws batch_size IDs at random *with replacement*,
# so X_test / y_test is a single random batch of 16 samples, not the whole
# test split, while `avg` below is computed over every test ID. Confirm this
# comparison is what you intend.
test_gen = generator(IDS, yields['test'], 16)
X_test, y_test = next(test_gen)

# Mean target value over all IDs in the test split. Uses the builtin sum()
# instead of a variable named `sum`, which would shadow the builtin for the
# rest of the module.
total_yield = sum(yields['test'][i] for i in IDS)
avg = total_yield / len(IDS)
print("Average Test Yield is ", avg)

model = load_model('CNN_LSTM_AVG_1000')

print(len(X_test), len(y_test))
# NOTE(review): the reported NotImplementedError comes from this call while
# TensorFlow traces the Keras test function (evaluate -> def_function._call ->
# _initialize -> func_graph). The pasted message is cut off after
# "in user code:", so the root cause is not visible here. For TF 2.x Keras
# models that contain LSTM layers, a frequently reported cause is
# "Cannot convert a symbolic Tensor ... to a numpy array", triggered by
# NumPy >= 1.20 together with TensorFlow <= 2.4. Check the full error text and
# the installed numpy/tensorflow versions before changing model code.
a, b = model.evaluate(X_test, y_test, batch_size=16)
y_pred = model.predict(X_test, batch_size=16, verbose=1)

用于加载和处理数据的 generator 函数：

def generator(IDs, yields, batch_size, cutoff=None):
    """Yield an endless stream of randomly sampled (features, yields) batches.

    Each batch draws ``batch_size`` IDs uniformly at random *with replacement*
    from ``IDs``, loads ``<data_dir><ID>.npy`` and pairs it with ``yields[ID]``.
    A sample whose array contains NaN is reported once via ``print('no', ID)``
    and excluded from all later draws. Every row of a yielded batch therefore
    holds real data and its matching label.

    Args:
        IDs: Sequence of sample identifiers (strings). Each one names a file
            ``<data_dir><ID>.npy``.
        yields: Mapping from ID to its target value.
        batch_size: Number of samples per batch.
        cutoff: If not None, features are read from ``Data/PROCESSED_III/``
            and truncated to the first ``cutoff`` entries along axis 0. Batches
            then have shape (batch_size, cutoff, 1, 256, 10). If None, features
            are read from ``Data/data_1/`` and batches have shape
            (batch_size, 14, 1, 32, 4).

    Yields:
        ``(batch_features, batch_yields)``. Both arrays are freshly allocated
        for every batch, so a consumer that queues batches never sees an
        earlier batch overwritten in place by a later one.

    Raises:
        ValueError: If ``IDs`` is empty, or if the data for every ID contains
            NaN (otherwise no valid batch could ever be produced).
    """
    import numpy as np
    import random

    # The two modes differ only in source directory and per-sample shape.
    if cutoff is not None:
        data_dir = 'Data/PROCESSED_III/'
        sample_shape = (cutoff, 1, 256, 10)
    else:
        data_dir = 'Data/data_1/'
        sample_shape = (14, 1, 32, 4)

    n_unique = len(set(IDs))
    if n_unique == 0:
        raise ValueError('IDs must not be empty')
    bad_ids = set()  # IDs already found to contain NaN; never drawn again

    def draw_valid_sample():
        # Redraw until a NaN-free sample is found. Previously a NaN sample left
        # its batch row with zeros or stale data from the previous batch, and
        # a matching zero or stale label.
        while True:
            ID = random.choice(IDs)
            if ID in bad_ids:
                continue
            sample = np.load(data_dir + ID + '.npy')  # load once, not twice
            if np.isnan(sample).any():
                print('no', ID)
                bad_ids.add(ID)
                if len(bad_ids) == n_unique:
                    raise ValueError('every sample in IDs contains NaN values')
                continue
            if cutoff is not None:
                sample = sample[:cutoff]
            return ID, sample

    while True:
        # Allocate new buffers for each batch rather than reusing one pair.
        batch_features = np.zeros((batch_size,) + sample_shape)
        batch_yields = np.zeros(batch_size)
        for i in range(batch_size):
            ID, sample = draw_valid_sample()
            batch_features[i] = sample
            batch_yields[i] = yields[ID]
        yield batch_features, batch_yields

报错信息：

Traceback (most recent call last):
  File "F:/Crop-Yield-Prediction-Using-CNN-LSTM/Crop-Yield-Prediction-Using-CNN-LSTM--Temp/Code/Train_HB.py", line 174, in <module>
    a, b = model.evaluate(X_test, y_test, batch_size=16)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\keras\engine\training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1081, in evaluate
    tmp_logs = test_function(iterator)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\def_function.py", line 506, in _initialize
    *args, **kwds))
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\function.py", line 2667, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\framework\func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\framework\func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
NotImplementedError: in user code:

恳请各位帮忙解答，谢谢！

回答

和开发者交流更多问题细节吧,去 写回答
相关文章

相似问题

相关问答用户
领券
问题归档 | 专栏文章 | 快讯文章归档 | 关键词归档 | 开发者手册归档 | 开发者手册 Section 归档