chainerのサンプルを探していると,古いものではtrainerを使っていないものも結構あります。
そこで,chainerでsin関数を学習させてみた
を書き直してみました。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | # https://qiita.com/hikobotch/items/018808ef795061176824 # https://github.com/kose/chainer-linear-regression/blob/master/train.py # とりあえず片っ端からimport import numpy as np import chainer from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils from chainer import Link, Chain, ChainList import chainer.functions as F import chainer.links as L import time import matplotlib matplotlib.use('Agg') from matplotlib import pyplot as plt import os # データ def get_dataset(N): x = np.linspace(0, 2 * np.pi, N) y = np.sin(x) dataset=[] for xx,yy in zip(x,y): dataset.append((xx,yy)) return dataset # ニューラルネットワーク class MyChain(Chain): def __init__(self, n_units=10): super(MyChain, self).__init__( l1=L.Linear(1, n_units), l2=L.Linear(n_units, n_units), l3=L.Linear(n_units, 1)) def __call__(self, x_data, y_data): x = Variable(x_data.astype(np.float32).reshape(len(x_data),1)) # Variableオブジェクトに変換 y = Variable(y_data.astype(np.float32).reshape(len(y_data),1)) # Variableオブジェクトに変換 loss= F.mean_squared_error(self.predict(x), y) chainer.reporter.report({ 'loss': loss }, self) return loss def predict(self, x): h1 = F.relu(self.l1(x)) h2 = F.relu(self.l2(h1)) h3 = self.l3(h2) return h3 def get_predata(self, x): return self.predict(Variable(x.astype(np.float32).reshape(len(x),1))).data from chainer.training import trigger as trigger_module class MyChartReport(chainer.training.extension.Extension): #trigger = (1, 'epoch') def __init__(self,model,N_test=900,trigger = (1, 'epoch')): self.N_test=N_test self.model=model self._trigger = trigger_module.get_trigger(trigger) def __call__(self,trainer): if self._trigger(trainer): with chainer.configuration.using_config('train', False): theta = np.linspace(0, 2 * np.pi, self.N_test) sin = np.sin(theta) test = self.model.get_predata(theta) plt.plot(theta, sin, label="sin") plt.plot(theta, test, label="test") plt.legend() plt.grid(True) plt.xlim(0, 2 * np.pi) plt.ylim(-1.2, 1.2) plt.title("sin") plt.xlabel("theta") plt.ylabel("amp") plt.savefig("fig_sin_epoch{}.png".format(trainer.updater.epoch)) plt.clf() # main if __name__ == "__main__": # 学習データ N = 1000 train = get_dataset(N) # テストデータ N_test = 900 test = get_dataset(N_test) # 学習パラメータ batchsize = 10 n_epoch = 500 n_units = 100 gpu=-1 # モデル作成 model = MyChain(n_units) optimizer = optimizers.Adam() optimizer.setup(model) train_iter = chainer.iterators.SerialIterator(train, batchsize) test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False) updater = chainer.training.updaters.StandardUpdater(train_iter, optimizer, device=gpu) trainer = chainer.training.Trainer(updater, (n_epoch, 'epoch'), out="snapshot") trainer.extend(chainer.training.extensions.Evaluator(test_iter, model, device=gpu)) trainer.extend(chainer.training.extensions.dump_graph('main/loss')) # Take a snapshot for each specified epoch frequency = 10 trainer.extend(chainer.training.extensions.snapshot(filename="snapshot_cureent"),trigger=(frequency, 'epoch')) # Write a log of evaluation statistics for each epoch #trainer.extend(chainer.training.extensions.LogReport()) trainer.extend(MyChartReport(model,trigger=(10, "epoch"))) trainer.extend(chainer.training.extensions.LogReport(trigger=(10, "epoch"))) trainer.extend(chainer.training.extensions.PrintReport( ["epoch", "main/loss", "validation/main/loss", "elapsed_time"])) #trainer.extend(chainer.training.extensions.ProgressBar()) trainer.extend(chainer.training.extensions.PlotReport(['main/loss', 'validation/main/loss'],trigger=(10, "epoch"),filename="loss.png")) trainer.run() model.to_cpu() chainer.serializers.save_npz("model.h5",model) |
ポイントは,extensionを使って,エポックごとにsinの学習結果を画像で保存しているところです。
結構いい加減に作っているので,画像保存ディレクトリやファイル名などは外から与えるようにするといいでしょう。