I created several classes for reading the MNIST handwritten digit data set in Java. These classes read the files after they have been decompressed (unzipped) from the archives available on the download site. Classes that can read the original (compressed) files are part of the small MnistReader project.
The following classes are self-contained (they have no dependencies on third-party libraries) and are essentially in the public domain, so they can simply be copied into your own projects. (Attribution is welcome, but not required.)
MnistDecompressedReader Class:
```java
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.Objects;
import java.util.function.Consumer;

/**
 * A class for reading the MNIST data set from the <b>decompressed</b>
 * (unzipped) files that are published at
 * <a href="http://yann.lecun.com/exdb/mnist/">
 * http://yann.lecun.com/exdb/mnist/</a>.
 */
public class MnistDecompressedReader {

    /**
     * Default constructor
     */
    public MnistDecompressedReader() {
        // Default constructor
    }

    /**
     * Read the MNIST training data from the given directory. The data is
     * assumed to be located in files with their default names,
     * <b>decompressed</b> from the original (gzipped) files:
     * <code>train-images.idx3-ubyte</code> and
     * <code>train-labels.idx1-ubyte</code>.
     *
     * @param inputDirectoryPath The input directory
     * @param consumer The consumer that will receive the resulting
     * {@link MnistEntry} instances
     * @throws IOException If an IO error occurs
     */
    public void readDecompressedTraining(Path inputDirectoryPath,
        Consumer<? super MnistEntry> consumer) throws IOException {
        String trainImagesFileName = "train-images.idx3-ubyte";
        String trainLabelsFileName = "train-labels.idx1-ubyte";
        Path imagesFilePath = inputDirectoryPath.resolve(trainImagesFileName);
        Path labelsFilePath = inputDirectoryPath.resolve(trainLabelsFileName);
        readDecompressed(imagesFilePath, labelsFilePath, consumer);
    }

    /**
     * Read the MNIST test data from the given directory, like
     * {@link #readDecompressedTraining(Path, Consumer)}, but from the files
     * <code>t10k-images.idx3-ubyte</code> and
     * <code>t10k-labels.idx1-ubyte</code>.
     */
    public void readDecompressedTesting(Path inputDirectoryPath,
        Consumer<? super MnistEntry> consumer) throws IOException {
        String testImagesFileName = "t10k-images.idx3-ubyte";
        String testLabelsFileName = "t10k-labels.idx1-ubyte";
        Path imagesFilePath = inputDirectoryPath.resolve(testImagesFileName);
        Path labelsFilePath = inputDirectoryPath.resolve(testLabelsFileName);
        readDecompressed(imagesFilePath, labelsFilePath, consumer);
    }

    /**
     * Read the MNIST data from the given (decompressed) images and labels files.
     */
    public void readDecompressed(Path imagesFilePath, Path labelsFilePath,
        Consumer<? super MnistEntry> consumer) throws IOException {
        // Buffer the file streams: the labels are read one byte at a time
        try (InputStream decompressedImagesInputStream = new BufferedInputStream(
                new FileInputStream(imagesFilePath.toFile()));
            InputStream decompressedLabelsInputStream = new BufferedInputStream(
                new FileInputStream(labelsFilePath.toFile()))) {
            readDecompressed(decompressedImagesInputStream,
                decompressedLabelsInputStream, consumer);
        }
    }

    /**
     * Read the MNIST data from the given streams, which must contain
     * decompressed IDX data. Each image is passed to the consumer as an
     * {@link MnistEntry}.
     */
    public void readDecompressed(
        InputStream decompressedImagesInputStream,
        InputStream decompressedLabelsInputStream,
        Consumer<? super MnistEntry> consumer) throws IOException {
        Objects.requireNonNull(consumer, "The consumer may not be null");

        DataInputStream imagesDataInputStream =
            new DataInputStream(decompressedImagesInputStream);
        DataInputStream labelsDataInputStream =
            new DataInputStream(decompressedLabelsInputStream);

        // The IDX headers are big-endian ints, which DataInputStream expects
        int magicImages = imagesDataInputStream.readInt();
        if (magicImages != 0x803) {
            throw new IOException("Expected magic header of 0x803 "
                + "for images, but found 0x" + Integer.toHexString(magicImages));
        }
        int magicLabels = labelsDataInputStream.readInt();
        if (magicLabels != 0x801) {
            throw new IOException("Expected magic header of 0x801 "
                + "for labels, but found 0x" + Integer.toHexString(magicLabels));
        }

        int numberOfImages = imagesDataInputStream.readInt();
        int numberOfLabels = labelsDataInputStream.readInt();
        if (numberOfImages != numberOfLabels) {
            throw new IOException("Found " + numberOfImages
                + " images but " + numberOfLabels + " labels");
        }

        int numRows = imagesDataInputStream.readInt();
        int numCols = imagesDataInputStream.readInt();

        for (int n = 0; n < numberOfImages; n++) {
            byte label = labelsDataInputStream.readByte();
            byte[] imageData = new byte[numRows * numCols];
            read(imagesDataInputStream, imageData);
            MnistEntry mnistEntry = new MnistEntry(
                n, label, numRows, numCols, imageData);
            consumer.accept(mnistEntry);
        }
    }

    // Fill the given array completely, or throw if the stream ends early
    private static void read(InputStream inputStream, byte[] data)
        throws IOException {
        int offset = 0;
        while (true) {
            int read = inputStream.read(data, offset, data.length - offset);
            if (read < 0) {
                break;
            }
            offset += read;
            if (offset == data.length) {
                return;
            }
        }
        throw new IOException("Tried to read " + data.length
            + " bytes, but only found " + offset);
    }
}
```
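The reader only checks the two magic numbers 0x803 and 0x801 and then reads big-endian ints. According to the IDX format description on the MNIST page, the third byte of the magic number encodes the data type (0x08 means unsigned byte) and the fourth byte the number of dimensions, so 0x803 means "unsigned bytes, 3 dimensions" and 0x801 means "unsigned bytes, 1 dimension". A minimal sketch of decoding such a header generically with `java.nio.ByteBuffer` (the class `IdxHeader` is hypothetical and not part of the MnistReader project):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical helper: decodes the header of an IDX file held in memory. */
public final class IdxHeader {
    public final int dataType;     // 0x08 means unsigned byte
    public final int[] dimensions; // e.g. {60000, 28, 28} for the training images

    private IdxHeader(int dataType, int[] dimensions) {
        this.dataType = dataType;
        this.dimensions = dimensions;
    }

    public static IdxHeader parse(byte[] bytes) {
        // IDX integers are big-endian (also ByteBuffer's default order)
        ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN);
        int magic = buffer.getInt();
        if ((magic >>> 16) != 0) {
            throw new IllegalArgumentException(
                "First two magic bytes must be zero, found 0x"
                    + Integer.toHexString(magic));
        }
        int dataType = (magic >>> 8) & 0xFF; // third byte
        int numDimensions = magic & 0xFF;    // fourth byte
        int[] dimensions = new int[numDimensions];
        for (int i = 0; i < numDimensions; i++) {
            // One big-endian int per dimension size
            dimensions[i] = buffer.getInt();
        }
        return new IdxHeader(dataType, dimensions);
    }
}
```

For the image files, the three dimension sizes are the number of images, the number of rows and the number of columns, which is exactly the order in which `readDecompressed` reads them.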
MnistEntry Class:
```java
import java.awt.image.BufferedImage;
import java.awt.image.DataBuffer;
import java.awt.image.DataBufferByte;

public class MnistEntry {
    private final int index;
    private final byte label;
    private final int numRows;
    private final int numCols;
    private final byte[] imageData;

    MnistEntry(int index, byte label, int numRows, int numCols,
        byte[] imageData) {
        this.index = index;
        this.label = label;
        this.numRows = numRows;
        this.numCols = numCols;
        this.imageData = imageData;
    }

    public int getIndex() {
        return index;
    }

    public byte getLabel() {
        return label;
    }

    public int getNumRows() {
        return numRows;
    }

    public int getNumCols() {
        return numCols;
    }

    public byte[] getImageData() {
        return imageData;
    }

    public BufferedImage createImage() {
        BufferedImage image = new BufferedImage(getNumCols(),
            getNumRows(), BufferedImage.TYPE_BYTE_GRAY);
        DataBuffer dataBuffer = image.getRaster().getDataBuffer();
        DataBufferByte dataBufferByte = (DataBufferByte) dataBuffer;
        byte[] data = dataBufferByte.getData();
        System.arraycopy(getImageData(), 0, data, 0, data.length);
        return image;
    }

    @Override
    public String toString() {
        String indexString = String.format("%05d", index);
        return "MnistEntry[" + "index=" + indexString + ","
            + "label=" + label + "]";
    }
}
```
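`getImageData()` returns the raw pixel bytes. Since Java's `byte` is signed, a pixel value of 255 shows up as -1 when used in arithmetic, so numeric processing should mask each value with `0xFF` first. A minimal sketch of converting one entry's pixels to floats in [0, 1], for example as input to a classifier (the class `MnistPixels` is hypothetical):

```java
/** Hypothetical helper: converts MNIST pixel bytes to floats in [0, 1]. */
public final class MnistPixels {
    private MnistPixels() {
        // Static utility class
    }

    public static float[] toNormalizedFloats(byte[] imageData) {
        float[] result = new float[imageData.length];
        for (int i = 0; i < imageData.length; i++) {
            // Mask to recover the unsigned pixel value 0..255
            int unsigned = imageData[i] & 0xFF;
            result[i] = unsigned / 255.0f;
        }
        return result;
    }
}
```

Usage would be along the lines of `float[] input = MnistPixels.toNormalizedFloats(mnistEntry.getImageData());`. Note that `createImage` copies the raw values into a `TYPE_BYTE_GRAY` image, where 0 is black, so the digits appear light on a dark background.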
The reader can be used to read the decompressed files. Each record is passed to the consumer as an MnistEntry instance:
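```java
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Paths;

public class MnistReaderExample {
    public static void main(String[] args) throws IOException {
        MnistDecompressedReader mnistReader = new MnistDecompressedReader();
        mnistReader.readDecompressedTraining(Paths.get("./data"), mnistEntry -> {
            System.out.println("Read entry " + mnistEntry);
            BufferedImage image = mnistEntry.createImage();
            // ... process the image here
        });
    }
}
```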
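The reader expects files that have already been unzipped. If you would rather read the original `.gz` downloads directly, one option is to wrap the file stream in `java.util.zip.GZIPInputStream` and hand the resulting image and label streams to `readDecompressed(InputStream, InputStream, Consumer)`. A minimal sketch, assuming gzipped files can be recognized by a `.gz` file name suffix (`IdxStreams.openIdxStream` is a hypothetical helper, not the MnistReader project's implementation):

```java
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.zip.GZIPInputStream;

/** Hypothetical helper: opens an IDX file, decompressing it if it ends with ".gz". */
public final class IdxStreams {
    private IdxStreams() {
        // Static utility class
    }

    public static InputStream openIdxStream(Path path) throws IOException {
        InputStream raw = new BufferedInputStream(Files.newInputStream(path));
        if (!path.getFileName().toString().endsWith(".gz")) {
            // Assumed to be an already decompressed IDX file
            return raw;
        }
        try {
            // The GZIPInputStream constructor reads the gzip header immediately
            return new GZIPInputStream(raw);
        } catch (IOException e) {
            // Do not leak the file handle if the file is not valid gzip data
            raw.close();
            throw e;
        }
    }
}
```

The caller is responsible for closing the returned stream, for example with a try-with-resources block like the one in `readDecompressed(Path, Path, Consumer)`.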
The MnistReader project contains several examples of how these classes can be used to read compressed or uncompressed data or to generate PNG images from MNIST records.
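For the PNG case, a minimal sketch of what such an export could look like with the standard `javax.imageio.ImageIO` API. It fills a `TYPE_BYTE_GRAY` image in the same way as `MnistEntry.createImage`, assuming row-major pixel data; `MnistPng.writePng` is a hypothetical helper, not the code from the MnistReader project:

```java
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;

/** Hypothetical helper: writes one row-major grayscale image to a PNG file. */
public final class MnistPng {
    private MnistPng() {
        // Static utility class
    }

    public static void writePng(byte[] imageData, int numRows, int numCols,
        File file) throws IOException {
        if (imageData.length != numRows * numCols) {
            throw new IllegalArgumentException("Expected " + numRows * numCols
                + " pixels, but found " + imageData.length);
        }
        BufferedImage image = new BufferedImage(
            numCols, numRows, BufferedImage.TYPE_BYTE_GRAY);
        // The raster of a new TYPE_BYTE_GRAY image is one byte per pixel, row by row
        byte[] target =
            ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        System.arraycopy(imageData, 0, target, 0, target.length);
        // ImageIO.write returns false if no writer for the format is available
        if (!ImageIO.write(image, "png", file)) {
            throw new IOException("No PNG writer available");
        }
    }
}
```

Inside the consumer, a call like `MnistPng.writePng(e.getImageData(), e.getNumRows(), e.getNumCols(), new File(...))` would write one file per entry.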