deeplearning(chainer)で超解像やってみた
最近、waifu2xというソフトウェアが話題になっています。
ultraist.hatenablog.com
画像拡大後、補正をかけることにより輪郭をシャープに見せるほか、ノイズを除去等できるようです。
ConvolutionalNeuralNetを適用することで実現しているようで、参考にされた論文はこちら。↓
Chao Dong, Chen Change Loy, Kaiming He, Xiaoou Tang, "Image Super-Resolution Using Deep Convolutional Networks"
http://arxiv.org/abs/1501.00092
Deepとは言いつつも、CNN3層とネットワーク構造が簡単で、画像が小さければCPUでも計算できそう?
ちょうどchainerを使ってみたいという気持ちがあったので、練習がてら簡単に実装してみました。
Deep Learning のフレームワーク Chainer を公開しました | Preferred Research
[学習について]
データのカテゴリについて:
画像のカテゴリによってネットワーク構造が変化すると考えられるため、カテゴリを統一しました。
今回は「森」と考えられる画像を適用します。(主観で)
画像サイズについて:
64✕64の画像を適用します。256✕256についても動かしてみましたが、CPUではウンとも動かなかったため断念。
画像データから、64✕64のROIを取得し、これを学習データとしました。
入出力データについて:
ネットワークへの入力には低画質化した画像を、出力にはオリジナルの画像を適用します。
低画質化の方法ですが、とりあえずscikit-imageパッケージのrescale関数を用いて
0.5倍→2.0倍とすることで低画質画像を作成しました。
色空間について:
RGB色空間のまま、補正ネットワークを学習します。
論文ではYCbCr空間に変換し、Yのみをネットワークにより補正しているようですが、今回はとりあえず雰囲気を見るということで。
[評価]
検証用画像を用意して、オリジナル・低画質化画像・補正結果を比較します。
オリジナル
低画質化画像
train epoch : 200
train epoch : 800
train epoch : 78000
epoch数200では学習が不十分で、入力よりも更に低画質化されています。
epoch数が増えるにつれ高画質化され、epoch数78000では入力画像よりも画質が向上しているように見えます。
(定量的な評価はしていません、、目視です)
画像が小さくてよくわからなかったので、大きな画像を64✕64のROIに分割し高画質化しました。
なんとなく画質が向上しているようです。
[所感]
SRCNN(SuperResolutionConvolutionalNeuralNet)について
効果を何となく確認することが出来ました。
やはり学習データが重要で、高解像度化したい対象に合わせてデータセットを準備しなければならないという印象。(waifu2xと同様に、今回学習したモデルをイラストに適用したところ、とても残念なことに)
低画質化のプロセスも重要で、除去したい要因(拡大、ピンぼけ、収差、etc)を考慮して、低画質化画像を作成する必要がありそう。
chianerについて
今までpylearn2しか使ったことがありませんが、chainerの方が便利に感じました。
データ構造やライブラリが理解しやすく、任意のネットワーク構築も比較的簡単にできそう。
今回も、なんとなく実装しましたが動きました。
DeepLearning実装の敷居が下がったのではないかと感じます。
[ソース]
(chainerのexampleを参考にしました。chainer/train_imagenet.py at master · chainer/chainer · GitHub)
(ザクっと作ったので、間違いがありましたらすみません。)
import cPickle as pickle from datetime import timedelta import json from multiprocessing import Pool from Queue import Queue import random import sys from threading import Thread import time import skimage.io import skimage.transform import numpy as np import copy import csv from chainer import cuda, Variable, FunctionSet, optimizers import chainer.functions as F class SRCNN(FunctionSet): insize = 64 outputsize = 64-(9-1)-(1-1)-(5-1) def __init__(self): super(SRCNN, self).__init__( conv1=F.Convolution2D(3,64,9), conv2=F.Convolution2D(64,32,1), conv3=F.Convolution2D(32,3,5), ) def forward(self, x_data, y_data, train=True): x = Variable(x_data, volatile=not train) t = Variable(y_data, volatile=not train) h = F.relu(self.conv1(x)) h = F.relu(self.conv2(h)) h = F.relu(self.conv3(h)) loss = F.mean_squared_error(h, t) return loss, h #input setting CONST_IMG_SIZE = 64 CONST_N_TRAIN = 10000 CONST_N_EVAL = 1000 CONST_PATH_TRAIN = "./output/train.csv" CONST_PATH_EVAL = "./output/eval.csv" CONST_N_BATCH = 50 CONST_N_BATCH_EVAL = 250 CONST_N_EPOCH = 1000 CONST_GPU_ID = -1 #cpu:-1,gpu:0~ CONST_N_LOADER = 1 #initial setting #craete model instance model = SRCNN() if CONST_GPU_ID >= 0: cuda.init(CONST_GPU_ID) model.to_gpu() # Setup optimizer optimizer = optimizers.MomentumSGD(lr=0.01, momentum=0.9) optimizer.setup(model.collect_parameters()) # Prepare dataset def load_image_list(path): pathlist = [] csvf = open(path) reader = csv.reader(csvf) for line in reader: pathlist.append(line[0]) csvf.close() return pathlist train_list = load_image_list(CONST_PATH_TRAIN) val_list = load_image_list(CONST_PATH_EVAL) # ------------------------------------------------------------------------------ # This example consists of three threads: data feeder, logger and trainer. These # communicate with each other via Queue. data_q = Queue(maxsize=1) res_q = Queue() # Data loading routine def read_image(path): image = skimage.io.imread(path,as_gray=False) v_rand = random.randint(0,3) if v_rand == 0: image = image[:,:,:] elif v_rand == 1: image = image[::-1,:,:] elif v_rand == 2: image = image[:,::-1,:] elif v_rand == 3: image = image[::-1,::-1,:] #create output data output_top = (CONST_IMG_SIZE- model.outputsize) / 2 output_bottom = CONST_IMG_SIZE- output_top image_out = image[output_top:output_bottom,output_top:output_bottom,:] image_out = image_out.transpose(2, 0, 1).astype(np.float32)/255. #create input coarse data image_in = skimage.transform.rescale(image,0.5) image_in = skimage.transform.rescale(image_in,2.)#/255 already 0~1 image_in = image_in.transpose(2, 0, 1).astype(np.float32) return image_in,image_out def eval_image_show(y,predict,name): y_convert = y.transpose(1, 2, 0) predict_convert = predict.transpose(1, 2, 0) #create coarse data x_convert = skimage.transform.rescale(y_convert,0.5) x_convert = skimage.transform.rescale(x_convert,2.) #convert for output y_convert[np.where(y_convert>1.0)] = 1.0 x_convert[np.where(x_convert>1.0)] = 1.0 predict_convert[np.where(predict_convert>1.0)] = 1.0 y_convert = (y_convert*255).astype(np.uint8) x_convert = (x_convert*255).astype(np.uint8) predict_convert = (predict_convert*255).astype(np.uint8) skimage.io.imsave(name+"in.png",x_convert) skimage.io.imsave(name+"out.png",y_convert) skimage.io.imsave(name+"pre.png",predict_convert) # Data feeder def feed_data(): i = 0 count = 0 x_batch = np.ndarray((CONST_N_BATCH, 3, CONST_IMG_SIZE, CONST_IMG_SIZE), dtype=np.float32) y_batch = np.ndarray((CONST_N_BATCH,3, model.outputsize, model.outputsize), dtype=np.float32) val_x_batch = np.ndarray((CONST_N_BATCH_EVAL, 3, CONST_IMG_SIZE, CONST_IMG_SIZE), dtype=np.float32) val_y_batch = np.ndarray((CONST_N_BATCH_EVAL,3, model.outputsize, model.outputsize), dtype=np.float32) batch_pool = [None] * CONST_N_BATCH val_batch_pool = [None] * CONST_N_BATCH_EVAL pool = Pool(CONST_N_LOADER) data_q.put('train') for epoch in xrange(1, 1 + CONST_N_EPOCH): print >> sys.stderr, 'epoch', epoch print >> sys.stderr, 'learning rate', optimizer.lr perm = np.random.permutation(len(train_list)) for idx in perm: path = train_list[idx] batch_pool[i] = pool.apply_async(read_image, args = (path, ),) i += 1 if i == CONST_N_BATCH: for j, x in enumerate(batch_pool): x_batch[j],y_batch[j] = x.get() data_q.put((x_batch.copy(), y_batch.copy())) i = 0 count += 1 if count % 1000 == 0: data_q.put('val') j = 0 for path in val_list: val_batch_pool[j] = pool.apply_async(read_image, args = (path, ),) j += 1 if j == CONST_N_BATCH_EVAL: for k, x in enumerate(val_batch_pool): val_x_batch[k],val_y_batch[k] = x.get() data_q.put((val_x_batch.copy(), val_y_batch.copy())) j = 0 data_q.put('train') optimizer.lr *= 0.97 pool.close() pool.join() data_q.put('end') # Logger def log_result(): train_count = 0 train_cur_loss = 0 begin_at = time.time() val_begin_at = None best_loss = np.Infinity while True: result = res_q.get() if result == 'end': print >> sys.stderr, '' break elif result == 'train': print >> sys.stderr, '' train = True if val_begin_at is not None: begin_at += time.time() - val_begin_at val_begin_at = None continue elif result == 'val': print >> sys.stderr, '' train = False val_count = val_loss = 0 val_begin_at = time.time() continue loss, y,predict,tmp_model = result if train: train_count += 1 duration = time.time() - begin_at throughput = train_count * CONST_N_BATCH / duration sys.stderr.write( '\rtrain {} updates ({} samples) time: {} ({} images/sec)' .format(train_count, train_count * CONST_N_BATCH, timedelta(seconds=duration), throughput)) train_cur_loss += loss if train_count % 20 == 0: y_tmp = y[0,:,:,:] pre_tmp = predict[0,:,:,:] eval_image_show(y=y_tmp,predict=pre_tmp,name="train_"+str(train_count)+"_") mean_loss = train_cur_loss / 20 print >> sys.stderr, '' print json.dumps({'type': 'train', 'iteration': train_count,'loss': mean_loss}) sys.stdout.flush() train_cur_loss = 0 else: val_count += CONST_N_BATCH_EVAL duration = time.time() - val_begin_at throughput = val_count / duration sys.stderr.write( '\rval {} batches ({} samples) time: {} ({} images/sec)' .format(val_count / CONST_N_BATCH_EVAL, val_count, timedelta(seconds=duration), throughput)) val_loss += loss if val_count == CONST_N_EVAL: y_tmp = y[0,:,:,:] pre_tmp = predict[0,:,:,:] eval_image_show(y=y_tmp,predict=pre_tmp,name="eval_"+str(train_count)+"_") mean_loss = val_loss * CONST_N_BATCH_EVAL / CONST_N_EVAL if(best_loss > mean_loss): filename_model = "model_" + str(mean_loss) + "_" + str(train_count) pickle.dump(tmp_model, open(filename_model, 'wb'), -1) best_loss = mean_loss print >> sys.stderr, '' print json.dumps({'type': 'val', 'iteration': train_count,'loss': mean_loss}) sys.stdout.flush() # Trainer def train_loop(): while True: while data_q.empty(): time.sleep(0.1) inp = data_q.get() if inp == 'end': # quit res_q.put('end') break elif inp == 'train': # restart training res_q.put('train') train = True continue elif inp == 'val': # start validation res_q.put('val') pickle.dump(model, open('model', 'wb'), -1) train = False continue x, y = inp if CONST_GPU_ID >= 0: x = cuda.to_gpu(x) y = cuda.to_gpu(y) if train: optimizer.zero_grads() loss, predict = model.forward(x, y) loss.backward() optimizer.update() else: loss, predict = model.forward(x, y, train=False) tmp_model = copy.deepcopy(model) tmp_model.to_cpu() res_q.put((float(cuda.to_cpu(loss.data)), cuda.to_cpu(y), cuda.to_cpu(predict.data), tmp_model)) del loss, predict, x, y # Invoke threads feeder = Thread(target=feed_data) feeder.daemon = True feeder.start() logger = Thread(target=log_result) logger.daemon = True logger.start() train_loop() feeder.join() logger.join() # Save final model pickle.dump(model, open('model', 'wb'), -1)