Applying R’s imager library to MNIST digits
Introduction
For machine learning, when it comes to training data, the more the better! Well, there are a whole bunch of disclaimers on that statement, but as long as the data is representative of the type of data the machine learning model will be used on in production, doubling the amount of unique training data will be more useful than doubling the number of epochs (if deep learning) or trees (if random forest, GBM, etc.).

I'm going to concentrate in this article on the MNIST data set, and look at how to use R's imager library to increase the number of training samples. I take a deep look at MNIST in my new book, Practical Machine Learning with H2O, which (not surprisingly) I highly recommend. But, briefly, the MNIST data is handwritten digits, and the challenge is to identify which of the 10 digits, 0 to 9, each one is.
There are 60,000 training samples, plus 10,000 test samples (and everyone uses the same test set - so watch out for the inadvertent over-fitting in published papers). The samples are 28x28 pixels, each a 0 to 255 greyscale value. I usually split the 60,000 samples into 50K for training and 10K for validation (to make sure I am not over-fitting on the test data).
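That split is just a random partition of the rows. A minimal sketch of how it could be done (the data frame name train60 and the seed are hypothetical, not from the book's code):

set.seed(42)                            # hypothetical seed, just for reproducibility
idx   <- sample(nrow(train60), 50000)   # assuming train60 holds all 60,000 rows
train <- train60[idx, ]                 # 50K for training
valid <- train60[-idx, ]                # 10K for validation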
imager and parallel
I used this R library for dealing with images: http://dahtah.github.io/imager/ which is based on a C++ library called CImg. Most function calls are just wrappers around the C++ code, which means they are fairly quick. It is well-documented with a good starter tutorial.

I used version 0.20 for all my development. I have just seen that 0.30 is now out, and in particular offers native parallelization. This is great! Out of scope for this article, but I used imager in conjunction with R's parallel functions, and found the latter quite clunky, with time spent copying data structures limiting scalability. On the other hand, the docs say these new parallel options work best on large images, and the 28x28 MNIST images are certainly not that. So maybe I am still stuck using parApply() and friends.

The Approach
In the 20,000 unseen samples (10K valid, 10K test), there are often examples of bad handwriting that we don't get to see in our 50,000 training samples. Therefore I am most interested in generating bad handwriting samples, not nice neat ones.

I spent a lot of time experimenting with what imager can produce, and settled on generating these effects, each with a random element (a sketch of how they might be strung together follows the list):
- rotate
- warp (make it “scruffier”)
- shift (move it 1 pixel up, down, left or right)
- bold (make it fatter - my code)
- dilate (make it fatter - cimg code)
- erode (make it thinner)
- erodedilate (one or the other)
- scratches (add lines)
- blotches (remove blobs)
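Here is that sketch: each effect becomes a small function that takes a cimg and returns a modified cimg, and one is picked at random per sample. Only a few of the effects are shown, the parameters are illustrative, and the real code is more careful (e.g. rotation needs the cropping described below):

library(imager)

applyRandomEffect <- function(img) {
  # A handful of the effects above, as simple imager calls
  effects <- list(
    function(x) imrotate(x, angle = rnorm(1, 0, 8), interpolation = 2),
    function(x) imshift(x, delta_x = sample(c(-1, 1), 1)),
    function(x) dilate_square(x, 3),
    function(x) erode_square(x, 3)
  )
  f <- effects[[sample(length(effects), 1)]]   # pick one at random
  f(img)
}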
In the full code I create the image in an imager (cimg) object called im, then copy it to im2. Each subsequent operation is performed on im2. im is left unchanged, but can be referred to for the initial state.
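For concreteness, here is a minimal sketch of that setup; trainX is a hypothetical numeric matrix with one row per digit and 784 pixel columns (0 to 255) - adjust to however your copy of MNIST is stored:

library(imager)
im  <- as.cimg(trainX[1, ], x = 28, y = 28)   # the untouched original
im2 <- im                                     # working copy; every effect modifies im2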
Rotate
The code to rotate comes in two parts. Here is the first part:

needSharpen = FALSE
angle = rnorm(1, 0, 8)
if(angle < -1.0 || angle > 1.0){
  im2 = imrotate(im2, angle, interpolation = 2)
  nPix = (width(im2) - width(im)) / 2
  im2 = crop.borders(im2, nPix = nPix)
  needSharpen = TRUE
}
The use of rnorm(sd = 8) means that 68% of the time the rotation will be within +/-8°, and only 5% of the time more than +/-16°. If my goal was simply more training samples, I'd perhaps have used a smaller sd, and/or clipped to a maximum rotation of 10°. But, as mentioned earlier, I wanted more scruffy handwriting. The if() block is a CPU optimization - if the rotation is less than 1°, don't bother doing anything.
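Those percentages are easy to check with a quick simulation in plain R:

a <- rnorm(1e6, mean = 0, sd = 8)
mean(abs(a) <= 8)    # roughly 0.68
mean(abs(a) > 16)    # roughly 0.05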
The imrotate() command takes the current im2 and replaces it with one that is rotated. This creates a larger image. To see what is going on, try running this (self-contained) script (see the inline comments for what is happening):

library(imager)
# Make 28x28 "mid-grey" square
im <- as.cimg(rep(128, 28*28), x = 28, y = 28)
#Prepare to plot side-by-side
par(mfrow = c(1,2))
#Show initial square 28x28
plot(im)
#Show rotated square, 34x34
plot(imrotate(im, angle = 16))
The output is like this:

You can see the image has become larger, to contain the rotated version. (That image also shows how imager's plot command will scale the colours based on the range, and that my choice of 128 could have been any non-zero number. When there is only a single value (128), it chooses a grey. After rotating we have 0 for the background and 128 for the square, so it shows 0 as black and 128 as white.)
For rotating MNIST digits:
- I want to keep the 28x28 size
- All the interesting content is in the middle, so clipping is fine.
So, after rotating, I call crop.borders(), which takes an argument nPix saying how many pixels to remove on each side. If the image has grown from 28 to 34 pixels square, nPix will be 3.
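Continuing the self-contained square example above, this is roughly how the rotated 34x34 image gets trimmed back to 28x28 (a sketch, not the exact code from the full script):

im3  <- imrotate(im, angle = 16)         # grows the square to 34x34
nPix <- (width(im3) - width(im)) / 2     # (34 - 28) / 2 = 3
im3  <- crop.borders(im3, nPix = nPix)   # back to 28x28
plot(im3)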
Feeling A Bit Vague…
Here is what one of the MNIST digits looks like rotated 30° at a time, 11 times.

In a perfect world, 12 rotations would give you exactly the image you started with. But you can see the effect of each rotation is to blur it slightly. If we did another lap, even your clever mammalian brain would no longer recognize it as a 4.
The author of the imager library provided a match.hist() function (see it, and the surrounding discussion, here: https://github.com/dahtah/imager/issues/17 ), which does a good (but not perfect) job. Here are the histograms of the image before rotation, after rotation, and then after match.hist:

You can judge the success from looking at the image on the right, or by seeing how the bars on the rightmost histogram match those of the leftmost one. (Yes, even though the bumps are very small, their height, and where they are, really matter!)
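If you want to look at those histograms yourself, something like this will do it (match.hist() here means the function from the GitHub issue above, not a built-in part of the package):

par(mfrow = c(1, 3))
hist(as.vector(im),  main = "before rotation")
hist(as.vector(im2), main = "after rotation")
im3 <- match.hist(im2, im)               # from the GitHub issue linked above
hist(as.vector(im3), main = "after match.hist")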
You better sharpen up…
You will have noticed the earlier rotate code set needSharpen to true. That is used by the following code. Some of the time it uses the imager library's imsharpen(), some of the time match.hist(), and some of the time a hack I wrote to make dim pixels dimmer and bright pixels brighter.

if(needSharpen){
  if(runif(1) < 0.8){
    # Usually: imager's own sharpen
    im2 <- imsharpen(im2, amplitude = 55)
  }
  if(runif(1) < 0.3){
    # My hack: dim pixels dimmer, bright pixels brighter, clamped to 0..255
    im2 <- ifelse(im2 < 128, im2 - 16, im2)
    im2 <- ifelse(im2 < 0, 0, im2)
    im2 <- ifelse(im2 > 200, im2 + 8, im2)
    im2 <- ifelse(im2 > 150, im2 + 8, im2)
    im2 <- ifelse(im2 > 100, im2 + 8, im2)
    im2 <- ifelse(im2 > 255, 255, im2)
  }else{
    # Otherwise: histogram matching against the original
    im2 <- match.hist(im2, im)
  }
}
The Others
The other image modifications, listed earlier, use imwarp(), imshift(), pmax() with imshift() (for a bold effect), dilate_square(), and erode_square(). The blotches and scratches were done by putting random noise on an image, then using pmax() or pmin() to combine them. If there is interest I can write another article going into the details.
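To give a flavour, here is a rough sketch of two of those effects, written from the description above rather than taken from the full code; the "scratches" here are just random bright pixels, where the real code draws more structured shapes:

# "bold": keep the brightest pixel from the original and four shifted copies
bold <- pmax(im2,
             imshift(im2, delta_x =  1), imshift(im2, delta_x = -1),
             imshift(im2, delta_y =  1), imshift(im2, delta_y = -1))

# crude "scratches": random bright noise, combined with pmax()
noise     <- as.cimg(255 * (runif(28 * 28) > 0.98), x = 28, y = 28)
scratched <- pmax(im2, noise)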
Timings
Using a single 2.8GHz core, I recorded these timings to process 60,000 28x28 images. (It was a 36-core 8xlarge EC2 machine, but my R script, and imager (at the time), only used one thread.)
- 304s to run "bold"
- 296s to run “shift”
- 417s for warp
- 447s for rotate
- 517s to 539s for all and all2
Summary
I made 20 files, so over 95% of my training data was generated. As you will discover if you read the book, this generated data gave a very useful boost in model strength, though of course it dramatically increased learning time, due to having 1.2 million training rows instead of 50,000. An interesting property was that the model found the generated data harder to learn: I got lower error rates on the unseen valid and test data sets than on the seen training data. This is a consequence of my deliberate decision to bias towards noisy data and scruffy handwriting.

Generating additional training data is a good way to prevent over-fitting, but generating data that is still representative is always a challenge. Normally you'd check the mean, sd, the distribution, etc. to see how you've done, and usually your only technique is to jitter - add a bit of random noise. With images you have a rich set of geometric transformations to choose from, and you often use your eyes to see how it has done. Though, as we saw from the image histograms, there are some objective techniques available too.