mikemoke blog

データ解析やってます。統計・機械学習・画像解析など。

deeplearning(chainer)で超解像やってみた

最近、waifu2xというソフトウェアが話題になっています。ultraist.hatenablog.com

画像拡大後、補正をかけることにより輪郭をシャープに見せるほか、ノイズを除去等できるようです。
ConvolutionalNeuralNetを適用することで実現しているようで、参考にされた論文はこちら。↓
Chao Dong, Chen Change Loy, Kaiming He, Xiaoou Tang, "Image Super-Resolution Using Deep Convolutional Networks"
http://arxiv.org/abs/1501.00092

Deepとは言いつつも、CNN3層とネットワーク構造が簡単で、画像が小さければCPUでも計算できそう?
ちょうどchainerを使ってみたいという気持ちがあったので、練習がてら簡単に実装してみました。
Deep Learning のフレームワーク Chainer を公開しました | Preferred Research


[学習について]
データのカテゴリについて:
画像のカテゴリによってネットワーク構造が変化すると考えられるため、カテゴリを統一しました。
今回は「森」と考えられる画像を適用します。(主観で)
画像サイズについて:
64✕64の画像を適用します。256✕256についても動かしてみましたが、CPUではウンとも動かなかったため断念。
画像データから、64✕64のROIを取得し、これを学習データとしました。
入出力データについて:
ネットワークへの入力には低画質化した画像を、出力にはオリジナルの画像を適用します。
低画質化の方法ですが、とりあえずscikit-imageパッケージのrescale関数を用いて
0.5倍→2.0倍とすることで低画質画像を作成しました。
色空間について:
RGB色空間のまま、補正ネットワークを学習します。
論文ではYCbCr空間に変換し、Yのみをネットワークにより補正しているようですが、今回はとりあえず雰囲気を見るということで。


[評価]
検証用画像を用意して、オリジナル・低画質化画像・補正結果を比較します。
f:id:mikemoke:20150628161015p:plainf:id:mikemoke:20150628163606p:plainf:id:mikemoke:20150628163628p:plainf:id:mikemoke:20150628163654p:plain オリジナル

f:id:mikemoke:20150628160833p:plainf:id:mikemoke:20150628163710p:plainf:id:mikemoke:20150628163722p:plainf:id:mikemoke:20150628163731p:plain 低画質化画像

f:id:mikemoke:20150628160848p:plainf:id:mikemoke:20150628163743p:plainf:id:mikemoke:20150628163748p:plainf:id:mikemoke:20150628163757p:plain train epoch : 200

f:id:mikemoke:20150628160909p:plainf:id:mikemoke:20150628163808p:plainf:id:mikemoke:20150628163813p:plainf:id:mikemoke:20150628163821p:plain train epoch : 800

f:id:mikemoke:20150628160936p:plainf:id:mikemoke:20150628163833p:plainf:id:mikemoke:20150628163840p:plainf:id:mikemoke:20150628163847p:plain train epoch : 78000

epoch数200では学習が不十分で、入力よりも更に低画質化されています。
epoch数が増えるにつれ高画質化され、epoch数78000では入力画像よりも画質が向上しているように見えます。
(定量的な評価はしていません、、目視です)

画像が小さくてよくわからなかったので、大きな画像を64✕64のROIに分割し高画質化しました。
なんとなく画質が向上しているようです。
f:id:mikemoke:20150628165627p:plainf:id:mikemoke:20150628165638p:plain


[所感]
SRCNN(SuperResolutionConvolutionalNeuralNet)について
効果を何となく確認することが出来ました。
やはり学習データが重要で、高解像度化したい対象に合わせてデータセットを準備しなければならないという印象。(waifu2xと同様に、今回学習したモデルをイラストに適用したところ、とても残念なことに)
低画質化のプロセスも重要で、除去したい要因(拡大、ピンぼけ、収差、etc)を考慮して、低画質化画像を作成する必要がありそう。
chianerについて
今までpylearn2しか使ったことがありませんが、chainerの方が便利に感じました。
データ構造やライブラリが理解しやすく、任意のネットワーク構築も比較的簡単にできそう。
今回も、なんとなく実装しましたが動きました。
DeepLearning実装の敷居が下がったのではないかと感じます。


[ソース]
(chainerのexampleを参考にしました。chainer/train_imagenet.py at master · pfnet/chainer · GitHub
(ザクっと作ったので、間違いがありましたらすみません。)

import cPickle as pickle
from datetime import timedelta
import json
from multiprocessing import Pool
from Queue import Queue
import random
import sys
from threading import Thread
import time
import skimage.io
import skimage.transform
import numpy as np
import copy
import csv

from chainer import cuda, Variable, FunctionSet, optimizers
import chainer.functions  as F

class SRCNN(FunctionSet):
    insize = 64
    outputsize = 64-(9-1)-(1-1)-(5-1)
    
    def __init__(self):
        super(SRCNN, self).__init__(
            conv1=F.Convolution2D(3,64,9),
            conv2=F.Convolution2D(64,32,1),
            conv3=F.Convolution2D(32,3,5),
        )
        
    def forward(self, x_data, y_data, train=True):
        x = Variable(x_data, volatile=not train)
        t = Variable(y_data, volatile=not train)

        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))

        loss = F.mean_squared_error(h, t)

        return loss, h


#input setting
CONST_IMG_SIZE = 64
CONST_N_TRAIN = 10000
CONST_N_EVAL = 1000
CONST_PATH_TRAIN = "./output/train.csv"
CONST_PATH_EVAL = "./output/eval.csv"
CONST_N_BATCH = 50
CONST_N_BATCH_EVAL = 250
CONST_N_EPOCH = 1000
CONST_GPU_ID = -1 #cpu:-1,gpu:0~
CONST_N_LOADER = 1

#initial setting
#craete model instance
model = SRCNN()
if CONST_GPU_ID >= 0:
    cuda.init(CONST_GPU_ID)
    model.to_gpu()

# Setup optimizer
optimizer = optimizers.MomentumSGD(lr=0.01, momentum=0.9)
optimizer.setup(model.collect_parameters())

# Prepare dataset
def load_image_list(path):
    pathlist = []
    csvf = open(path)
    reader = csv.reader(csvf)
    for line in reader:
        pathlist.append(line[0])
    csvf.close()
    return pathlist

train_list = load_image_list(CONST_PATH_TRAIN)
val_list   = load_image_list(CONST_PATH_EVAL)

# ------------------------------------------------------------------------------
# This example consists of three threads: data feeder, logger and trainer. These
# communicate with each other via Queue.
data_q = Queue(maxsize=1)
res_q  = Queue()

# Data loading routine
def read_image(path):
    image = skimage.io.imread(path,as_gray=False)
    v_rand = random.randint(0,3)
    if v_rand == 0:
        image = image[:,:,:]
    elif v_rand == 1:
        image = image[::-1,:,:]
    elif v_rand == 2:
        image = image[:,::-1,:]
    elif v_rand == 3:
        image = image[::-1,::-1,:]
    #create output data
    output_top = (CONST_IMG_SIZE- model.outputsize) / 2
    output_bottom = CONST_IMG_SIZE- output_top
    image_out = image[output_top:output_bottom,output_top:output_bottom,:]
    image_out = image_out.transpose(2, 0, 1).astype(np.float32)/255.
    #create input coarse data
    image_in = skimage.transform.rescale(image,0.5)
    image_in = skimage.transform.rescale(image_in,2.)#/255 already 0~1
    image_in = image_in.transpose(2, 0, 1).astype(np.float32)
    
    return image_in,image_out

def eval_image_show(y,predict,name):
    y_convert = y.transpose(1, 2, 0)
    predict_convert = predict.transpose(1, 2, 0)
    #create coarse data
    x_convert = skimage.transform.rescale(y_convert,0.5)
    x_convert = skimage.transform.rescale(x_convert,2.)
    
    #convert for output
    y_convert[np.where(y_convert>1.0)] = 1.0
    x_convert[np.where(x_convert>1.0)] = 1.0
    predict_convert[np.where(predict_convert>1.0)] = 1.0
    y_convert = (y_convert*255).astype(np.uint8)
    x_convert = (x_convert*255).astype(np.uint8)
    predict_convert = (predict_convert*255).astype(np.uint8)

    skimage.io.imsave(name+"in.png",x_convert)
    skimage.io.imsave(name+"out.png",y_convert)
    skimage.io.imsave(name+"pre.png",predict_convert)

# Data feeder
def feed_data():
    i     = 0
    count = 0

    x_batch = np.ndarray((CONST_N_BATCH, 3, CONST_IMG_SIZE, CONST_IMG_SIZE), dtype=np.float32)
    y_batch = np.ndarray((CONST_N_BATCH,3, model.outputsize, model.outputsize), dtype=np.float32)
    val_x_batch = np.ndarray((CONST_N_BATCH_EVAL, 3, CONST_IMG_SIZE, CONST_IMG_SIZE), dtype=np.float32)
    val_y_batch = np.ndarray((CONST_N_BATCH_EVAL,3, model.outputsize, model.outputsize), dtype=np.float32)

    batch_pool     = [None] * CONST_N_BATCH
    val_batch_pool = [None] * CONST_N_BATCH_EVAL
    pool           = Pool(CONST_N_LOADER)
    data_q.put('train')
    for epoch in xrange(1, 1 + CONST_N_EPOCH):
        print >> sys.stderr, 'epoch', epoch
        print >> sys.stderr, 'learning rate', optimizer.lr
        perm = np.random.permutation(len(train_list))
        for idx in perm:
            path = train_list[idx]
            batch_pool[i] = pool.apply_async(read_image, args = (path, ),)
            i += 1

            if i == CONST_N_BATCH:
                for j, x in enumerate(batch_pool):
                    x_batch[j],y_batch[j] = x.get()
                data_q.put((x_batch.copy(), y_batch.copy()))
                i = 0

            count += 1
            if count % 1000 == 0:
                data_q.put('val')
                j = 0
                for path in val_list:
                    val_batch_pool[j] = pool.apply_async(read_image, args = (path, ),)
                    j += 1

                    if j == CONST_N_BATCH_EVAL:
                        for k, x in enumerate(val_batch_pool):
                            val_x_batch[k],val_y_batch[k] = x.get()
                        data_q.put((val_x_batch.copy(), val_y_batch.copy()))
                        j = 0
                data_q.put('train')

        optimizer.lr *= 0.97
    pool.close()
    pool.join()
    data_q.put('end')

# Logger
def log_result():
    train_count = 0
    train_cur_loss = 0
    begin_at = time.time()
    val_begin_at = None
    best_loss = np.Infinity
    while True:
        result = res_q.get()
        if result == 'end':
            print >> sys.stderr, ''
            break
        elif result == 'train':
            print >> sys.stderr, ''
            train = True
            if val_begin_at is not None:
                begin_at += time.time() - val_begin_at
                val_begin_at = None
            continue
        elif result == 'val':
            print >> sys.stderr, ''
            train = False
            val_count = val_loss = 0
            val_begin_at = time.time()
            continue

        loss, y,predict,tmp_model = result
        if train:
            train_count += 1
            duration     = time.time() - begin_at
            throughput   = train_count * CONST_N_BATCH / duration
            sys.stderr.write(
                '\rtrain {} updates ({} samples) time: {} ({} images/sec)'
                .format(train_count, train_count * CONST_N_BATCH,
                        timedelta(seconds=duration), throughput))

            train_cur_loss += loss
            if train_count % 20 == 0:
                y_tmp = y[0,:,:,:]
                pre_tmp = predict[0,:,:,:]
                eval_image_show(y=y_tmp,predict=pre_tmp,name="train_"+str(train_count)+"_")
                mean_loss  = train_cur_loss / 20
                print >> sys.stderr, ''
                print json.dumps({'type': 'train', 'iteration': train_count,'loss': mean_loss})
                sys.stdout.flush()
                train_cur_loss = 0
        else:
            val_count  += CONST_N_BATCH_EVAL
            duration    = time.time() - val_begin_at
            throughput  = val_count / duration
            sys.stderr.write(
                '\rval   {} batches ({} samples) time: {} ({} images/sec)'
                .format(val_count / CONST_N_BATCH_EVAL, val_count,
                        timedelta(seconds=duration), throughput))

            val_loss += loss
            if val_count == CONST_N_EVAL:
                y_tmp = y[0,:,:,:]
                pre_tmp = predict[0,:,:,:]
                eval_image_show(y=y_tmp,predict=pre_tmp,name="eval_"+str(train_count)+"_")
                mean_loss  = val_loss * CONST_N_BATCH_EVAL / CONST_N_EVAL
                if(best_loss > mean_loss):
                    filename_model = "model_" + str(mean_loss) + "_" + str(train_count)
                    pickle.dump(tmp_model, open(filename_model, 'wb'), -1)
                    best_loss = mean_loss
                print >> sys.stderr, ''
                print json.dumps({'type': 'val', 'iteration': train_count,'loss': mean_loss})
                sys.stdout.flush()

# Trainer
def train_loop():
    while True:
        while data_q.empty():
            time.sleep(0.1)
        inp = data_q.get()
        if inp == 'end':  # quit
            res_q.put('end')
            break
        elif inp == 'train':  # restart training
            res_q.put('train')
            train = True
            continue
        elif inp == 'val':  # start validation
            res_q.put('val')
            pickle.dump(model, open('model', 'wb'), -1)
            train = False
            continue

        x, y = inp
        if CONST_GPU_ID >= 0:
            x = cuda.to_gpu(x)
            y = cuda.to_gpu(y)

        if train:
            optimizer.zero_grads()
            loss, predict = model.forward(x, y)
            loss.backward()
            optimizer.update()
        else:
            loss, predict = model.forward(x, y, train=False)
        tmp_model = copy.deepcopy(model)
        tmp_model.to_cpu()
        res_q.put((float(cuda.to_cpu(loss.data)),
                   cuda.to_cpu(y),
                   cuda.to_cpu(predict.data),
                   tmp_model))
        del loss, predict, x, y

# Invoke threads
feeder = Thread(target=feed_data)
feeder.daemon = True
feeder.start()
logger = Thread(target=log_result)
logger.daemon = True
logger.start()

train_loop()
feeder.join()
logger.join()

# Save final model
pickle.dump(model, open('model', 'wb'), -1)