I’ve been playing with GANs for a few days. They are incredibly cool and incredibly fussy.

The general idea of a GAN (generative adversarial network), if you’re not aware, is that you have two neural nets in competition with each other. The first is trained as a generator, and its job is to generate new content. You might use it to remove watermarks, to remove creases and discoloration from old photos, or to colorise black and white images, that type of thing. After you train your generator to a very basic (but not awesome) level, the second network is trained as a basic classifier to determine whether a given image is real or generated. Then you train them in turns, so the generator gets better at generating fake content and the critic (also known as the discriminator) gets better at spotting it.
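
If that sounds abstract, here’s a minimal sketch of the turn-taking in plain PyTorch. This is just to make the idea concrete; it’s not the fastai code I actually used, and the tiny placeholder models and losses are made up for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder networks: a tiny "restorer" generator and a tiny critic.
generator = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1))
critic = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1),
                       nn.Flatten(), nn.Linear(8, 1))
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
opt_c = torch.optim.Adam(critic.parameters(), lr=1e-4)

def gan_step(crappy, clean):
  # Critic's turn: clean images should score as real (1), generated ones as fake (0).
  real, fake = torch.ones(clean.size(0), 1), torch.zeros(clean.size(0), 1)
  loss_c = (F.binary_cross_entropy_with_logits(critic(clean), real) +
            F.binary_cross_entropy_with_logits(critic(generator(crappy).detach()), fake))
  opt_c.zero_grad()
  loss_c.backward()
  opt_c.step()

  # Generator's turn: try to make the critic call its output real.
  loss_g = F.binary_cross_entropy_with_logits(critic(generator(crappy)), real)
  opt_g.zero_grad()
  loss_g.backward()
  opt_g.step()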

It’s a very simple idea which works very well. My attempts at producing something useful were based on the Fastai lesson 7 superres GAN notebook, which is a good place to start if you want to play with this idea.

The easiest way to train your generator is to take some nice clean images and, in the now immortal words of Jeremy, “crappify them”. In the lesson notebook he wanted to show how to make higher-res images from crappy low-res ones, so his crappification function resizes the image (down) with some nasty interpolation, draws a number onto it, and then saves it with a random jpg quality. If you wanted to remove creases from photos you’d have to write a function that puts creases into photos, so you have your xs (crappified) and ys (original images); if you wanted to colorise images you’d turn them black and white, and so on.
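
For reference, a crappify function looks something like this. It’s a rough sketch from memory rather than the notebook’s exact code; the 96×96 size and the quality range are just the sort of values the lesson uses.

import random
from PIL import Image, ImageDraw

def crappify(src_path, dest_path, low_res=96):
  img = Image.open(src_path).convert("RGB")
  img = img.resize((low_res, low_res), resample=Image.BILINEAR)  # nasty downscale
  # Scribble a random number somewhere in the top-left quadrant.
  draw = ImageDraw.Draw(img)
  draw.text((random.randint(0, low_res // 2), random.randint(0, low_res // 2)),
            str(random.randint(10, 99)), fill=(255, 255, 255))
  img.save(dest_path, quality=random.randint(10, 70))  # random jpg quality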

I wanted to train a GAN to remove chain-link fence from photos. I haven’t done photography as a hobby for decades, but apparently the pain of knowing “that shot would have been perfect if it wasn’t for that poxy fence” never leaves you. So my crappify function took some transparent PNGs of chain-link fence and pasted a fence in front of every image.
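
My version looked roughly like this (a sketch: the fences folder, the file names and the function name are placeholders for wherever you keep your transparent fence PNGs).

import random
from pathlib import Path
from PIL import Image

fence_pngs = list(Path("fences").glob("*.png"))  # transparent chain-link overlays

def fenceify(src_path, dest_path):
  img = Image.open(src_path).convert("RGB")
  fence = Image.open(random.choice(fence_pngs)).convert("RGBA")
  fence = fence.resize(img.size)           # stretch the fence over the whole photo
  img.paste(fence, (0, 0), mask=fence)     # use the alpha channel as the paste mask
  img.save(dest_path, quality=90)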

So far so good. You then train a unet with crappified images as inputs and clean ones as targets.
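
In fastai v1 that part looks roughly like this. The paths, batch size and exact arguments are from memory of the lesson notebook, so treat it as a sketch rather than a verbatim copy.

from fastai.vision import *

path_crappy, path_clean = Path('data/crappy'), Path('data/clean')  # placeholder folders

# Databunch of (crappified, clean) image pairs.
data_gen = (ImageImageList.from_folder(path_crappy)
            .split_by_rand_pct(0.1, seed=42)
            .label_from_func(lambda x: path_clean/x.name)
            .transform(get_transforms(max_zoom=2.), size=128, tfm_y=True)
            .databunch(bs=32)
            .normalize(imagenet_stats, do_y=True))

# A unet generator with an ImageNet-pretrained resnet34 encoder,
# trained with plain pixel (MSE) loss to start with.
learn_gen = unet_learner(data_gen, models.resnet34, wd=1e-3, blur=True,
                         norm_type=NormType.Weight, self_attention=True,
                         y_range=(-3., 3.), loss_func=MSELossFlat())
learn_gen.fit_one_cycle(2)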

One of the problems with training GANs is the very beginning, when the generator knows nothing about anything and flaps around in the dark for some time before it eventually figures out how to generate something vaguely real-looking. Fastai gets around this (like most other things) with transfer learning: you start out with a generator which already knows what things look like (because it’s pre-trained on ImageNet), which makes this part of the process much quicker.

After 2 epochs the generator does a reasonable job most of the time but as you can see with the second cat, there is still sometimes a lot of ghosting where the fence has been removed.

Next you train a critic to tell generated images from clean ones, and then hand both networks to a fastai object which trains them in turns and does all the switching back and forth for you. It uses a loss function which is a combination of pixel loss and critic loss, with the critic making sure that the generated content looks real and the pixel loss making sure that it looks like the target image. Bear in mind that you could be training this to do something like in-fills by drawing black blocks over parts of your photos, so you want the generated content to be contextually suitable. It’s no good having something which looks realistic if it doesn’t look like it’s part of your photo.
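
In fastai v1 terms that step looks roughly like this. Again it’s from memory rather than a verbatim copy of the notebook; the folder names, weights and thresholds are the sort of values the lesson uses, and learn_gen is the generator learner from before.

from functools import partial
from torch import nn, optim
from fastai.vision import *
from fastai.vision.gan import *

path = Path('data')  # placeholder; 'image_gen' and 'images' hold generated and clean images

# The critic is just a binary classifier over generated vs clean images.
data_crit = (ImageList.from_folder(path, include=['image_gen', 'images'])
             .split_by_rand_pct(0.1, seed=42)
             .label_from_folder(classes=['image_gen', 'images'])
             .transform(get_transforms(max_zoom=2.), size=128)
             .databunch(bs=32)
             .normalize(imagenet_stats))
learn_crit = Learner(data_crit, gan_critic(),
                     loss_func=AdaptiveLoss(nn.BCEWithLogitsLoss()), wd=1e-3)
learn_crit.fit_one_cycle(6, 1e-3)

# GANLearner handles the turn-taking; weights_gen balances pixel loss against critic loss.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1., 50.),
                                 show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=1e-3)
learn.fit(10, 1e-4)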

Lots of training later and the yellow-backgrounded kitty is still a problem child, but looking better. It’s actually hard to tell on my tablet screen, but I think there is still a fair bit of ghosting in all the generated images; it’s just more obvious on a plain background, which wouldn’t normally be as much of an issue. I also believe it’s not generalising as well as it might on images it hasn’t seen, but I could fix that by adding more variations of fence and training again.

The main problem so far is how long it takes to train. Getting to the point you see above took a total of 8 hours of training on Colab, and it’s still not that great. You can speed up the training process by starting on smaller data and then increasing the size gradually towards the end. The bad news is that the input and output size of the network is set from the beginning, so if you start on 128×128, that’s what you’re stuck with. (The other bad news is that the 8 hours was with starting small.)
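
For completeness, the “start small, bump the size up later” part is just a matter of rebuilding the databunch at a bigger size and pointing the learner at it. A rough sketch, where get_data is a hypothetical helper wrapping the databunch code from earlier:

def get_data(bs, size):
  return (ImageImageList.from_folder(path_crappy)
          .split_by_rand_pct(0.1, seed=42)
          .label_from_func(lambda x: path_clean/x.name)
          .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
          .databunch(bs=bs)
          .normalize(imagenet_stats, do_y=True))

learn.data = get_data(bs=16, size=192)  # bigger images, smaller batch, same weights
learn.fit(10, 1e-4)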

There is a paper for a system called Anysize GAN which can train and predict on images of different resolutions and dimensions (they no longer have to be square), but at the time of writing it has not yet been accepted and the authors have not released the source code, so that doesn’t help us right now.

My plan was to train my network on 256×256 and then feed larger images through the GAN in tiles, then reassemble them.

Tiling and reassembling works just fine. Training straight onto 256×256 was slooooooow (and frankly I’d have preferred even larger anyway). Maybe if I had a big grunty box of my own it would have been feasible.

So now I’m just down to one problem (apart from the horrible ghosting, shhhhh). 128×128 tiles are so small that it can’t always tell that it’s looking at part of a fence. Doh. I guess I’d better wait for Anysize GAN after all.

So that’s what I managed. Underwhelming. I think that if I understood the hyper-params better I could probably squeeze better results out of it, but this was at the end of part 1 of the course, and it was very much a case of “these parameters should work well most of the time but I’m not going to teach you too much about them until later”.

I’m also not sure at this point where I should be poking it. Should I train the generator more at the beginning? Should I train the GAN more before I start bumping the size up, or should I train for more epochs while I’m bumping the size up? Or am I better off fiddling with the hyper-params? Once I can answer these questions I’m sure I’ll come back to this and try to improve it.

Appendix:

If you do actually have a GAN which works nicely and you want the tile/untile functions, they are here. They handle rectangular images and “odd-sized” images which don’t divide evenly by your tile size (the last row and column of tiles simply overlap their neighbours).

from PIL import Image as PImage

# Cut a PIL image into tile_sz x tile_sz tiles, returned as a 2D list (rows of tiles).
# If the image doesn't divide evenly, the last row/column of tiles is anchored to the
# bottom/right edge, so those tiles overlap their neighbours rather than being smaller.
def tile_image(img, tile_sz):
  xs = img.width // tile_sz
  ys = img.height // tile_sz
  x_has_mod = img.width % tile_sz > 0    # need an extra (overlapping) column?
  y_has_mod = img.height % tile_sz > 0   # need an extra (overlapping) row?
  tiles = [[None] * (xs + x_has_mod) for _ in range(ys + y_has_mod)]

  for y in range(ys + y_has_mod):
    if y + 1 > ys:
      tly, bry = img.height - tile_sz, img.height   # last row: snap to the bottom edge
    else:
      tly, bry = tile_sz * y, tile_sz * (y + 1)

    for x in range(xs + x_has_mod):
      if x + 1 > xs:
        tlx, brx = img.width - tile_sz, img.width   # last column: snap to the right edge
      else:
        tlx, brx = tile_sz * x, tile_sz * (x + 1)

      tiles[y][x] = img.crop((tlx, tly, brx, bry))

  return tiles

# Reassemble the tiles from tile_image back into a single image of the original size.
def untile_image(tiles, img_width, img_height):
  img = PImage.new("RGB", (img_width, img_height))
  tile_sz = tiles[0][0].width

  xs = img_width // tile_sz
  ys = img_height // tile_sz
  x_has_mod = img_width % tile_sz > 0
  y_has_mod = img_height % tile_sz > 0

  for y in range(ys + y_has_mod):
    if y + 1 > ys:
      tly, bry = img_height - tile_sz, img_height
    else:
      tly, bry = tile_sz * y, tile_sz * (y + 1)

    for x in range(xs + x_has_mod):
      if x + 1 > xs:
        tlx, brx = img_width - tile_sz, img_width
      else:
        tlx, brx = tile_sz * x, tile_sz * (x + 1)

      img.paste(tiles[y][x], (tlx, tly, brx, bry))

  return img

# img = PImage.open("test.jpg")
# tiles = tile_image(img, tile_sz=128)
# untile_image(tiles, img.width, img.height)

I use them like this:

import torchvision.transforms as tfms

# `Image` (fastai's image class) and `learn_gen` (the trained generator) come from
# the fastai v1 notebook, i.e. they're in scope after `from fastai.vision import *`.

def pil2fast(img):
  # PIL image -> fastai Image (tensor-backed)
  return Image(tfms.ToTensor()(img))

def fast2pil(img):
  # fastai Image -> PIL image
  return tfms.ToPILImage()(img.data).convert("RGB")

def predict(fn):
  TSZ = 128   # must match the tile size the GAN was trained on
  img = PImage.open(fn)

  tiles = tile_image(img, TSZ)

  # Run every tile through the generator, then stitch the results back together.
  for y in range(len(tiles)):
    for x in range(len(tiles[0])):
      pred, _, _ = learn_gen.predict(pil2fast(tiles[y][x]))
      tiles[y][x] = fast2pil(pred)

  return untile_image(tiles, img.width, img.height)