我创建了一个函数,它以x、y、批大小为输入,并生成带有cython的小批处理,以加快处理速度。
import numpy as np
cimport cython
cimport numpy as np
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False)
def create_mini_batches(np.ndarray[DTYPE_t, ndim=2] X, np.ndarray[DTYPE_t, ndim=2] y, int batch_size):
cdef int m
cdef double num_of_batch
cdef np.ndarray[DTYPE_t, ndim=2] shuffle_X
cdef np.ndarray[DTYPE_t, ndim=2] shuffle_y
cdef int permutation
X, y = X.T, y.T
m = X.shape[0]
num_of_batch = m // batch_size
permutation = list(np.random.permutation(m))
shuffle_X = X[permutation, :]
shuffle_y = y[permutation, :]
for t in range(num_of_batch):
mini_x = shuffle_X[t * batch_size: (t + 1) * batch_size, :]
mini_y = shuffle_y[t * batch_size: (t + 1) * batch_size, :]
yield (mini_x.T, mini_y.T)
if m % batch_size != 0:
mini_x = shuffle_X[m // batch_size * batch_size: , :]
mini_y = shuffle_y[m // batch_size * batch_size: , :]
yield (mini_x.T, mini_y.T)
当我用此代码python setup.py build_ext --inplace
编译程序时,会出现以下错误。
@cython.boundscheck(False)
def create_mini_batches(np.ndarray\[DTYPE_t, ndim=2\] X, np.ndarray\[DTYPE_t, ndim=2\] y, int batch_size):
^
test.pyx:8:24: Buffer types only allowed as function local variables
有人能帮我解决这个错误吗?为什么是一个错误?
发布于 2022-12-04 08:27:17
在这种情况下,这是一条令人眼花缭乱的错误消息,但您会得到它,因为它是一个生成器,而不是一个函数。这意味着Cython必须创建内部数据结构,以便在生成器工作时保持生成器状态。
类型化的Numpy数组变量(例如np.ndarray[DTYPE_t, ndim=2]
)是以一种很难正确处理其引用计数的方式实现的。因此,Cython只能将它们作为常规函数中的变量处理。它不能将它们存储在类中,因此不能在生成器中使用它们。
为了解决这个问题,您要么需要放弃输入,要么您应该切换到最近设计得更好的输入内存视图,这样就没有这个限制了。
https://stackoverflow.com/questions/74673759
复制相似问题