import colorsys
import random
import sys

import numpy
import PIL.Image as Image
import PIL.ImageDraw as ImageDraw
import pybrain.structure.modules
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.tools.shortcuts import buildNetwork

import texture


class StatusHandler:
    """Report the progress of named tasks on a single stdout line.

    A task is started with newtask(), advanced with update() and closed
    with done(). The progress line is redrawn in place by starting it
    with a carriage return.
    """

    def __init__(self):
        # Last percentage written to the terminal (None until a task starts).
        self.last_status = None
        # Name of the task in progress, or None when idle.
        self.taskname = None

    def done(self):
        """Finish the current task, if any, and print its 'done' line."""
        if self.taskname is not None:
            print('\r' + self.taskname + '... done')
            self.taskname = None

    def newtask(self, name):
        """Start a task called `name`, closing any task still running."""
        if self.taskname is not None:
            self.done()
        self.taskname = name
        self.last_status = 0
        self.print_line()

    def update(self, current, maximum):
        """Redraw the progress line when the integer percentage changes."""
        # A zero maximum means there is nothing to do; report it as
        # complete instead of dividing by zero.
        if maximum:
            status = int((100 * current) / maximum)
        else:
            status = 100
        if self.last_status != status:
            self.last_status = status
            self.print_line()

    def info(self, message):
        """Close the running task, then print `message` on its own line."""
        self.done()
        print(message)

    def print_line(self):
        """Write 'taskname... NN%' over the current line and flush stdout."""
        # No trailing newline: the next write starts with '\r' and
        # overwrites this line.
        sys.stdout.write('\r%s... %d%%' % (self.taskname, self.last_status))
        sys.stdout.flush()


class PixelClassifier:
    """Classify image pixels with a small neural network.

    Each sampled pixel is described by its colour and by features of the
    grayscale tile around it (see get_sample). train() fits the network to
    a red/green mask; classify() renders the network output as a grayscale
    image.

    NOTE(review): the tile cut out in get_sample is 2 * (block_size // 2) + 1
    pixels wide, which equals block_size only when block_size is odd.
    Presumably the texture kernels are block_size x block_size, so an even
    block_size would not line up - confirm against texture.make_kernel.
    """

    def __init__(self, image, block_size=15):
        self.set_image(image)
        self.block_size = block_size
        self.halfsize = self.block_size // 2
        self.build_kernels()
        # One throwaway sample tells us how many inputs the network needs.
        self.input_count = len(self.get_sample(0, 0))
        self.net = buildNetwork(self.input_count, self.input_count, 1)
        self.status = StatusHandler()

    def set_image(self, image):
        """Select the image that get_sample(), train() and classify() read."""
        self.image = image
        self.grayimage = image.convert('L')
        self.width, self.height = self.image.size

    def make_kernel(self, *args, **kwargs):
        """Call texture.make_kernel with block_size forced to this instance's."""
        kwargs['block_size'] = self.block_size
        return texture.make_kernel(*args, **kwargs)

    def build_kernels(self):
        """Create the texture kernels and save each one as an image.

        Side effect: writes kernel00.png, kernel01.png, ... relative to the
        current working directory.
        """
        self.kernels = [
            # Round kernels
            self.make_kernel((0, 1), (1, 0)),
            self.make_kernel((0, 2), (2, 0)),
            self.make_kernel((0, 3), (3, 0)),
            self.make_kernel((0, 1), (1, 0), total_r=3),
        ]

        for i, kernel in enumerate(self.kernels):
            image = texture.kernel_to_image(kernel)
            image.save('kernel%02d.png' % i)

    def get_sample(self, x, y):
        """Return the feature tuple for the pixel at (x, y).

        The features are, in order: the red, green and blue values of the
        pixel divided by 255; the difference between the brightest and the
        darkest gray value of the surrounding tile divided by 255; and, for
        each kernel, abs(sum(kernel * tile)) / 255.
        """
        r, g, b = self.image.getpixel((x, y))[:3]
        features = [r / 255., g / 255., b / 255.]

        # Move the tile centre inwards so the crop box stays inside the
        # image. The box ends at centre + halfsize + 1 (exclusive), so the
        # largest valid centre is size - halfsize - 1; clamping to
        # size - halfsize let the tile reach one pixel past the edge.
        x = min(max(x, self.halfsize), self.width - self.halfsize - 1)
        y = min(max(y, self.halfsize), self.height - self.halfsize - 1)

        coords = (x - self.halfsize, y - self.halfsize,
                  x + self.halfsize + 1, y + self.halfsize + 1)
        tile = self.grayimage.crop(coords)

        darkest, brightest = tile.getextrema()
        features.append((brightest - darkest) / 255.)

        array = numpy.array(tile)
        for kernel in self.kernels:
            value = abs((kernel * array).sum())
            features.append(value / 255.)

        return tuple(features)

    def train(self, mask, step=8, passes=3):
        """Train the network on the current image using a colour mask.

        Every `step`-th pixel in both directions is examined. Pure green
        (0, 255, 0) mask pixels become samples with target 1, pure red
        (255, 0, 0) pixels become samples with target -1, and every other
        colour is skipped. The trainer is then run `passes` times and the
        error it returns is printed after each pass.

        Raises:
            ValueError: if `mask` and the current image differ in size.
        """
        if mask.size != self.image.size:
            raise ValueError("mask size %r does not match image size %r"
                             % (mask.size, self.image.size))

        # Normalise the mask so getpixel() always yields an (r, g, b)
        # tuple; with an RGBA or palette mask the comparisons below would
        # otherwise never match.
        mask = mask.convert('RGB')

        self.status.newtask("Collecting training samples")
        dataset = SupervisedDataSet(self.input_count, 1)
        for y in range(0, self.height, step):
            self.status.update(y, self.height)
            for x in range(0, self.width, step):
                maskval = mask.getpixel((x, y))
                if maskval == (0, 255, 0):
                    output = 1
                elif maskval == (255, 0, 0):
                    output = -1
                else:
                    continue

                sample = self.get_sample(x, y)
                dataset.addSample(sample, output)

        self.status.done()
        self.status.info("Got %d samples" % len(dataset))

        trainer = BackpropTrainer(self.net, dataset)
        for i in range(passes):
            self.status.newtask("Training %d" % i)
            self.status.info("Error %0.3f" % trainer.train())

    def classify(self, mask=None, step=8):
        """Run the network over the current image and draw the result.

        Every `step`-th pixel in both directions is sampled. When `mask` is
        given it is converted to mode '1' and pixels whose mask value is
        zero are skipped. Each network output `v` is drawn as a rectangle
        filled with gray level int(v * 127 + 128), limited to 0..255.

        Returns:
            A new mode 'L' image with the same size as the current image.

        Raises:
            ValueError: if `mask` is given and differs in size from the
                current image.
        """
        if mask is not None:
            if mask.size != self.image.size:
                raise ValueError("mask size %r does not match image size %r"
                                 % (mask.size, self.image.size))
            mask = mask.convert('1')

        dataset = SupervisedDataSet(self.input_count, 1)
        coords = []
        self.status.newtask("Collecting samples")
        for y in range(0, self.height, step):
            self.status.update(y, self.height)
            for x in range(0, self.width, step):
                if mask is not None and not mask.getpixel((x, y)):
                    continue

                sample = self.get_sample(x, y)
                coords.append((x, y))
                # The target is a placeholder; only the inputs are used.
                dataset.addSample(sample, 0)

        self.status.newtask("Classifying samples")
        output = self.net.activateOnDataset(dataset)

        self.status.newtask("Drawing output image")
        image = Image.new('L', self.image.size)
        draw = ImageDraw.Draw(image)
        for (value,), (x, y) in zip(output, coords):
            # Nothing here bounds the network output, so keep the gray
            # level inside the 8-bit range before handing it to PIL.
            level = max(0, min(255, int(value * 127 + 128)))
            draw.rectangle((x, y, x + step, y + step), fill=level)

        self.status.done()
        return image

def main(argv=None):
    """Train a PixelClassifier on one image and classify another.

    `argv` is the full argument vector including the program name and
    defaults to sys.argv. It must hold exactly four file names after the
    program name: training image, training mask, test image and output
    image. With any other argument count the usage line is printed and the
    process exits with status 1.
    """
    if argv is None:
        argv = sys.argv

    if len(argv) != 5:
        print("Usage: %s train.png trainmask.png test.png output.png"
              % argv[0])
        sys.exit(1)

    train_path, mask_path, test_path, output_path = argv[1:5]

    train = Image.open(train_path)
    mask = Image.open(mask_path)
    test = Image.open(test_path)

    classifier = PixelClassifier(train)
    classifier.train(mask)

    # Reuse the trained network on the test image.
    classifier.set_image(test)
    output = classifier.classify()
    output.save(output_path)


if __name__ == '__main__':
    main()