chainerのサンプルを探していると,古いものではtrainerを使っていないものも結構あります。
そこで,chainerでsin関数を学習させてみた
を書き直してみました。
# https://qiita.com/hikobotch/items/018808ef795061176824
# https://github.com/kose/chainer-linear-regression/blob/master/train.py
# とりあえず片っ端からimport
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import time
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import os
# データ
def get_dataset(N):
x = np.linspace(0, 2 * np.pi, N)
y = np.sin(x)
dataset=[]
for xx,yy in zip(x,y):
dataset.append((xx,yy))
return dataset
# ニューラルネットワーク
class MyChain(Chain):
def __init__(self, n_units=10):
super(MyChain, self).__init__(
l1=L.Linear(1, n_units),
l2=L.Linear(n_units, n_units),
l3=L.Linear(n_units, 1))
def __call__(self, x_data, y_data):
x = Variable(x_data.astype(np.float32).reshape(len(x_data),1)) # Variableオブジェクトに変換
y = Variable(y_data.astype(np.float32).reshape(len(y_data),1)) # Variableオブジェクトに変換
loss= F.mean_squared_error(self.predict(x), y)
chainer.reporter.report({
'loss': loss
}, self)
return loss
def predict(self, x):
h1 = F.relu(self.l1(x))
h2 = F.relu(self.l2(h1))
h3 = self.l3(h2)
return h3
def get_predata(self, x):
return self.predict(Variable(x.astype(np.float32).reshape(len(x),1))).data
from chainer.training import trigger as trigger_module
class MyChartReport(chainer.training.extension.Extension):
#trigger = (1, 'epoch')
def __init__(self,model,N_test=900,trigger = (1, 'epoch')):
self.N_test=N_test
self.model=model
self._trigger = trigger_module.get_trigger(trigger)
def __call__(self,trainer):
if self._trigger(trainer):
with chainer.configuration.using_config('train', False):
theta = np.linspace(0, 2 * np.pi, self.N_test)
sin = np.sin(theta)
test = self.model.get_predata(theta)
plt.plot(theta, sin, label="sin")
plt.plot(theta, test, label="test")
plt.legend()
plt.grid(True)
plt.xlim(0, 2 * np.pi)
plt.ylim(-1.2, 1.2)
plt.title("sin")
plt.xlabel("theta")
plt.ylabel("amp")
plt.savefig("fig_sin_epoch{}.png".format(trainer.updater.epoch))
plt.clf()
# main
if __name__ == "__main__":
# 学習データ
N = 1000
train = get_dataset(N)
# テストデータ
N_test = 900
test = get_dataset(N_test)
# 学習パラメータ
batchsize = 10
n_epoch = 500
n_units = 100
gpu=-1
# モデル作成
model = MyChain(n_units)
optimizer = optimizers.Adam()
optimizer.setup(model)
train_iter = chainer.iterators.SerialIterator(train, batchsize)
test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)
updater = chainer.training.updaters.StandardUpdater(train_iter, optimizer, device=gpu)
trainer = chainer.training.Trainer(updater, (n_epoch, 'epoch'), out="snapshot")
trainer.extend(chainer.training.extensions.Evaluator(test_iter, model, device=gpu))
trainer.extend(chainer.training.extensions.dump_graph('main/loss'))
# Take a snapshot for each specified epoch
frequency = 10
trainer.extend(chainer.training.extensions.snapshot(filename="snapshot_cureent"),trigger=(frequency, 'epoch'))
# Write a log of evaluation statistics for each epoch
#trainer.extend(chainer.training.extensions.LogReport())
trainer.extend(MyChartReport(model,trigger=(10, "epoch")))
trainer.extend(chainer.training.extensions.LogReport(trigger=(10, "epoch")))
trainer.extend(chainer.training.extensions.PrintReport(
["epoch", "main/loss", "validation/main/loss", "elapsed_time"]))
#trainer.extend(chainer.training.extensions.ProgressBar())
trainer.extend(chainer.training.extensions.PlotReport(['main/loss', 'validation/main/loss'],trigger=(10, "epoch"),filename="loss.png"))
trainer.run()
model.to_cpu()
chainer.serializers.save_npz("model.h5",model)
ポイントは,extensionを使って,エポックごとにsinの学習結果を画像で保存しているところです。
結構いい加減に作っているので,画像保存ディレクトリやファイル名などは外から与えるようにするといいでしょう。