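So I recently finished a mini-batch algorithm for a library I'm building in Java (an artificial neural network library). I then trained my network to solve the XOR problem with mini-batch sizes of 2 or 3, and both gave worse accuracy than what I got with a batch size of 1 (which is basically SGD). I understand that I need to train it for more epochs, but I'm not seeing any improvement in runtime speed, which, from what I've read, is supposed to happen. Why is that?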
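One way to make the runtime question concrete: if each example in a mini-batch is still back-propagated on its own, the number of backprop calls per epoch is fixed by the number of training examples, and only the number of weight updates shrinks as the batch size grows. The sketch below simply counts both for the four-example XOR set. `MiniBatchCounts` and its methods are hypothetical names, and it assumes a final, smaller batch is kept, since the question does not show `createMiniBatches`.

```java
// Hypothetical helper: counts per-epoch work for a mini-batch loop that
// back-propagates each example separately and then applies one update per batch.
final class MiniBatchCounts {
    private MiniBatchCounts() {}

    /** Number of mini-batches, i.e. weight updates, per epoch (last batch may be smaller). */
    static int updatesPerEpoch(int numExamples, int batchSize) {
        if (numExamples < 0 || batchSize <= 0) {
            throw new IllegalArgumentException("need numExamples >= 0 and batchSize > 0");
        }
        return (numExamples + batchSize - 1) / batchSize; // ceiling division
    }

    /** Backprop calls per epoch: every example is visited once, whatever the batch size. */
    static int backpropCallsPerEpoch(int numExamples, int batchSize) {
        if (numExamples < 0 || batchSize <= 0) {
            throw new IllegalArgumentException("need numExamples >= 0 and batchSize > 0");
        }
        return numExamples;
    }

    public static void main(String[] args) {
        int n = 4; // the XOR truth table has four examples
        for (int b = 1; b <= 3; b++) {
            System.out.println("batch size " + b
                    + ": backprop calls = " + backpropCallsPerEpoch(n, b)
                    + ", weight updates = " + updatesPerEpoch(n, b));
        }
    }
}
```

Running `main` reports 4 backprop calls for every batch size, with 4, 2 and 2 weight updates for batch sizes 1, 2 and 3. Under these assumptions the per-epoch work does not drop, while the network gets fewer updates per epoch, which is consistent with the lower accuracy after the same number of epochs.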
Here is my code (Java):
public void SGD(double[][] inputs, double[][] expected_outputs, int mini_batch_size, int epochs, boolean verbose) {
    // Set verbose
    setVerbose(verbose);
    // Create training set
    TrainingSet trainingSet = new TrainingSet(inputs, expected_outputs);
    // Loop through epochs
    for (int i = 0; i < epochs; i++) {
        // Print progress
        print("\rTrained: " + i + "/" + epochs);
        // Shuffle training set
        trainingSet.shuffle();
        // Create the mini-batches
        TrainingSet.Data[][] mini_batches = createMiniBatches(trainingSet, mini_batch_size);
        // Loop through mini-batches
        for (int j = 0; j < mini_batches.length; j++) {
            update_mini_batch(mini_batches[j]);
        }
    }
    // Print progress
    print("\rTrained: " + epochs + "/" + epochs);
    print("\nDone!");
}

private Pair backprop(double[] inputs, double[] target_outputs) {
    // Create expected output column matrix
    Matrix EO = Matrix.fromArray(new double[][]{target_outputs});
    // Forward-propagate inputs
    feedForward(inputs);
    // Get the errors, which are also the bias deltas
    Matrix[] Errors = calculateError(EO);
    // Weight delta matrices
    Matrix[] dCdW = new Matrix[Errors.length];
    // Calculate the deltas
    // First layer's delta
    dCdW[0] = Matrix.dot(Matrix.transpose(I), Errors[0]);
    // Rest of the network
    for (int i = 1; i < Errors.length; i++) {
        dCdW[i] = Matrix.dot(Matrix.transpose(H[i - 1]), Errors[i]);
    }
    return new Pair(dCdW, Errors);
}

private void update_mini_batch(TrainingSet.Data[] mini_batch) {
    // Get the first deltas
    Pair deltas = backprop(mini_batch[0].input, mini_batch[0].output);
    // Loop through the mini-batch and sum the deltas
    for (int i = 1; i < mini_batch.length; i++) {
        deltas.add(backprop(mini_batch[i].input, mini_batch[i].output));
    }
    // Multiply the deltas by the learning rate
    // and divide by the mini-batch size to get
    // the mean of the deltas
    deltas.multiply(learningRate / mini_batch.length);
    // Update weights and biases
    for (int i = 0; i < W.length; i++) {
        W[i].subtract(deltas.dCdW[i]);
        B[i].subtract(deltas.dCdB[i]);
    }
}

Posted on 2019-03-06 22:55:40
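In the code above, `update_mini_batch` still calls `backprop` once per example and sums the results with `deltas.add`, so each mini-batch costs as many small products as there are examples. The speed-up usually attributed to mini-batches comes from restructuring that sum: stacking the batch inputs as rows of one matrix X and the error rows as one matrix D, the summed weight gradient Σ xᵢᵀδᵢ equals the single product XᵀD. The sketch below, with plain arrays and hypothetical names (`BatchedGradient`, `perExample`, `batched`), only shows that the two forms agree; it is not the asker's `Matrix` API.

```java
// Sketch: summing per-example weight gradients x_i^T * delta_i (one small
// product per example) equals one batched product X^T * D over the mini-batch.
final class BatchedGradient {
    private BatchedGradient() {}

    /** Plain matrix product: a (n x k) times b (k x m). */
    static double[][] matmul(double[][] a, double[][] b) {
        int n = a.length, k = b.length, m = b[0].length;
        double[][] c = new double[n][m];
        for (int i = 0; i < n; i++) {
            if (a[i].length != k) throw new IllegalArgumentException("shape mismatch");
            for (int p = 0; p < k; p++) {
                double aip = a[i][p];
                for (int j = 0; j < m; j++) {
                    c[i][j] += aip * b[p][j];
                }
            }
        }
        return c;
    }

    static double[][] transpose(double[][] a) {
        double[][] t = new double[a[0].length][a.length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[0].length; j++) {
                t[j][i] = a[i][j];
            }
        }
        return t;
    }

    /** Per-example loop: sums the outer products x_i^T * delta_i. */
    static double[][] perExample(double[][] inputs, double[][] deltas) {
        double[][] sum = new double[inputs[0].length][deltas[0].length];
        for (int i = 0; i < inputs.length; i++) {
            double[][] xT = transpose(new double[][] {inputs[i]});   // in x 1
            double[][] g = matmul(xT, new double[][] {deltas[i]});   // in x out
            for (int r = 0; r < g.length; r++) {
                for (int c = 0; c < g[0].length; c++) {
                    sum[r][c] += g[r][c];
                }
            }
        }
        return sum;
    }

    /** Batched: a single product X^T * D for the whole mini-batch. */
    static double[][] batched(double[][] inputs, double[][] deltas) {
        return matmul(transpose(inputs), deltas);
    }
}
```

With hand-written loops like `matmul`, both versions do the same arithmetic, so this alone will not run faster; the gain typically appears when the single large product is handed to an optimized (for example BLAS-backed) matrix routine instead of many tiny ones.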
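My understanding is that mini-batches are not really meant to speed up the computation, but rather to make computation on large datasets feasible.
If you have 1,000,000 examples, it is difficult for a computer to compute the forward and backward pass over all of them at once, but passing batches of 5,000 elements is much more manageable.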
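For the numbers in this answer, 1,000,000 examples in batches of 5,000 gives 200 mini-batches, and so 200 weight updates, per epoch. A minimal sketch of one common way to form such batches, assuming the data can be addressed by index: shuffle an index array, then slice it, so only the examples of one batch need to go through the forward and backward pass at a time. `IndexBatcher` and `shuffledBatches` are hypothetical names; the question's `createMiniBatches` is not shown and may behave differently, for example with a final partial batch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

// Hypothetical helper: shuffles example indices (Fisher-Yates) and cuts
// them into consecutive mini-batches; the last batch may be smaller.
final class IndexBatcher {
    private IndexBatcher() {}

    static List<int[]> shuffledBatches(int numExamples, int batchSize, Random rng) {
        if (numExamples < 0 || batchSize <= 0) {
            throw new IllegalArgumentException("need numExamples >= 0 and batchSize > 0");
        }
        int[] order = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            order[i] = i;
        }
        // In-place Fisher-Yates shuffle
        for (int i = numExamples - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = order[i];
            order[i] = order[j];
            order[j] = tmp;
        }
        List<int[]> batches = new ArrayList<>();
        for (int start = 0; start < numExamples; start += batchSize) {
            int end = Math.min(start + batchSize, numExamples);
            batches.add(Arrays.copyOfRange(order, start, end));
        }
        return batches;
    }
}
```

Calling `shuffledBatches(1_000_000, 5_000, new Random(42))` yields 200 index arrays of 5,000 entries each, covering every example exactly once.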
For your case, I'd recommend two things:
https://datascience.stackexchange.com/questions/46798