# Cursor into the training set; advanced by next_batch() on every call.
index = 0


def next_batch(n):
    """Return the next ``n`` training samples as an ``(xs, ys)`` pair.

    ``xs`` is the image slice converted to float32, reshaped to
    ``(n, 784)`` and divided by 255.  ``ys`` is whatever ``to_onehotv``
    returns for the matching label slice (presumably one-hot vectors --
    confirm against ``to_onehotv``).

    If the requested batch would run past sample 60000, the cursor is
    reset to 0 first, so any leftover samples at the tail are skipped.

    Note: this version calls ``mnist.train_images()`` and
    ``mnist.train_labels()`` on every invocation.
    """
    global index

    # Wrap around before slicing when the batch would overrun the set.
    if 60000 < index + n:
        index = 0

    start, stop = index, index + n

    pixels = mnist.train_images()[start:stop]
    xs = np.float32(pixels).reshape((n, 28 ** 2)) / 255
    ys = to_onehotv(mnist.train_labels()[start:stop])

    index = stop
    return xs, ys
优化后(将图像的类型转换、reshape 和归一化以及标签的读取移到函数外,只在模块加载时执行一次):
# Put the following code, together with some of the other code, into its own
# module, mnistpack.py.

# Convert, flatten and scale the whole training set once, at import time,
# so that next_batch() only has to slice.
train_xs = np.float32(mnist.train_images()).reshape((60000, 784)) / 255
mnist_train_labels = mnist.train_labels()

# Cursor into the training set; advanced by next_batch() on every call.
index = 0


def next_batch(n):
    """Return the next ``n`` training samples as an ``(xs, ys)`` pair.

    ``xs`` is a slice of the pre-processed ``train_xs`` array and ``ys``
    is whatever ``to_onehotv`` returns for the matching slice of
    ``mnist_train_labels`` (presumably one-hot vectors -- confirm against
    ``to_onehotv``).

    If the requested batch would run past sample 60000, the cursor is
    reset to 0 first, so any leftover samples at the tail are skipped.
    """
    global index

    # Wrap around before slicing when the batch would overrun the set.
    if 60000 < index + n:
        index = 0

    start, stop = index, index + n

    xs = train_xs[start:stop]
    ys = to_onehotv(mnist_train_labels[start:stop])

    index = stop
    return xs, ys