chainerでsin関数を学習させてみたをtrainerで書き直す

chainerのサンプルを探していると,古いものではtrainerを使っていないものも結構あります。
そこで,chainerでsin関数を学習させてみた
を書き直してみました。




# https://qiita.com/hikobotch/items/018808ef795061176824
# https://github.com/kose/chainer-linear-regression/blob/master/train.py

# とりあえず片っ端からimport
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import time
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import os
# データ
def get_dataset(N):
    x = np.linspace(0, 2 * np.pi, N)
    y = np.sin(x)
    dataset=[]
    for xx,yy in zip(x,y):
        dataset.append((xx,yy))
    return dataset

# ニューラルネットワーク
class MyChain(Chain):
    def __init__(self, n_units=10):
        super(MyChain, self).__init__(
             l1=L.Linear(1, n_units),
             l2=L.Linear(n_units, n_units),
             l3=L.Linear(n_units, 1))

    def __call__(self, x_data, y_data):
        x = Variable(x_data.astype(np.float32).reshape(len(x_data),1)) # Variableオブジェクトに変換
        y = Variable(y_data.astype(np.float32).reshape(len(y_data),1)) # Variableオブジェクトに変換
        loss= F.mean_squared_error(self.predict(x), y)
        chainer.reporter.report({
            'loss': loss
        }, self)

        return loss

    def  predict(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        h3 = self.l3(h2)
        return h3

    def get_predata(self, x):
        return self.predict(Variable(x.astype(np.float32).reshape(len(x),1))).data

from chainer.training import trigger as trigger_module
class MyChartReport(chainer.training.extension.Extension):

    #trigger = (1, 'epoch')

    def __init__(self,model,N_test=900,trigger = (1, 'epoch')):
        self.N_test=N_test
        self.model=model
        self._trigger = trigger_module.get_trigger(trigger)

    def __call__(self,trainer):
        if self._trigger(trainer):
            with chainer.configuration.using_config('train', False):
                theta = np.linspace(0, 2 * np.pi, self.N_test)
                sin = np.sin(theta)
                test = self.model.get_predata(theta)
                plt.plot(theta, sin, label="sin")
                plt.plot(theta, test, label="test")
                plt.legend()
                plt.grid(True)
                plt.xlim(0, 2 * np.pi)
                plt.ylim(-1.2, 1.2)
                plt.title("sin")
                plt.xlabel("theta")
                plt.ylabel("amp")
                plt.savefig("fig_sin_epoch{}.png".format(trainer.updater.epoch)) 
                plt.clf()



# main
if __name__ == "__main__":

    # 学習データ
    N = 1000
    train = get_dataset(N)

    # テストデータ
    N_test = 900
    test = get_dataset(N_test)

    # 学習パラメータ
    batchsize = 10
    n_epoch = 500
    n_units = 100
    gpu=-1

    # モデル作成
    model = MyChain(n_units)
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)

    updater = chainer.training.updaters.StandardUpdater(train_iter, optimizer, device=gpu)
    trainer = chainer.training.Trainer(updater, (n_epoch, 'epoch'), out="snapshot")

    trainer.extend(chainer.training.extensions.Evaluator(test_iter, model, device=gpu))
    trainer.extend(chainer.training.extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = 10
    trainer.extend(chainer.training.extensions.snapshot(filename="snapshot_cureent"),trigger=(frequency, 'epoch'))


    # Write a log of evaluation statistics for each epoch
    #trainer.extend(chainer.training.extensions.LogReport())

    trainer.extend(MyChartReport(model,trigger=(10, "epoch")))

    trainer.extend(chainer.training.extensions.LogReport(trigger=(10, "epoch")))
    trainer.extend(chainer.training.extensions.PrintReport(
        ["epoch", "main/loss", "validation/main/loss", "elapsed_time"]))
    #trainer.extend(chainer.training.extensions.ProgressBar())
    trainer.extend(chainer.training.extensions.PlotReport(['main/loss', 'validation/main/loss'],trigger=(10, "epoch"),filename="loss.png"))
    trainer.run()

    model.to_cpu()
    chainer.serializers.save_npz("model.h5",model)


ポイントは,extensionを使って,エポックごとにsinの学習結果を画像で保存しているところです。
結構いい加減に作っているので,画像保存ディレクトリやファイル名などは外から与えるようにするといいでしょう。