こちら
で紹介しているSeq2Seqを動く形で作成しなおしてみました。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | train_pair = [ ["初めまして。", "初めまして。よろしくお願いします。"], ["どこから来たんですか?", "日本から来ました。"], ["日本のどこに住んでるんですか?", "東京に住んでいます。"], ["仕事は何してますか?", "私は会社員です。"], ["お会いできて嬉しかったです。", "私もです!"], ["おはよう。", "おはようございます。"], ["いつも何時に起きますか?", "6時に起きます。"], ["朝食は何を食べますか?", "たいていトーストと卵を食べます。"], ["朝食は毎日食べますか?", "たまに朝食を抜くことがあります。"], ["野菜をたくさん取っていますか?", "毎日野菜を取るようにしています。"], ["週末は何をしていますか?", "友達と会っていることが多いです。"], ["どこに行くのが好き?", "私たちは渋谷に行くのが好きです。"] ] test_pair = [ ["初めまして。", "初めまして。よろしくお願いします。"], ["どこから来たんですか?", "米国から来ました。"], ["米国のどこに住んでるんですか?", "ニューヨークに住んでいます。"], ["おはよう。", "おはよう。"], ["いつも何時に起きますか?", "7時に起きます。"], ["夕食は何を食べますか?", "たいていトーストと卵を食べます。"], ["夕食は毎日食べますか?", "たまに朝食を抜くことがあります。"], ["肉をたくさん取っていますか?", "毎日インクを取るようにしています。"], ["週頭は何をしていますか?", "友達と会っていることが多いです。"], ] # https://nojima.hatenablog.com/entry/2017/10/10/023147 import nltk import MeCab import chainer from chainer.datasets import split_dataset_random , cifar import chainer.links as L import chainer.functions as F import random import numpy as np import codecs SIZE=10000 EOS=1 UNK=0 class EncoderDecoder(chainer.Chain): def __init__(self, n_vocab, n_out, n_hidden): super(EncoderDecoder,self).__init__() with self.init_scope(): self.embed_x = L.EmbedID(n_vocab, n_hidden) self.embed_y = L.EmbedID(n_out,n_hidden) self.encoder = L.NStepLSTM( n_layers=1, in_size=n_hidden, out_size=n_hidden, dropout=0.1) self.decoder = L.NStepLSTM( n_layers=1, in_size=n_hidden, out_size=n_hidden, dropout=0.1) self.W = L.Linear(n_hidden, n_out) def __call__(self, xs , ys ): xs = [x[::-1] for x in xs] eos = self.xp.array([EOS], dtype=np.int32) ys_in = [F.concat((eos, y), axis=0) for y in ys] ys_out = [F.concat((y, eos), axis=0) for y in ys] # Both xs and ys_in are lists of arrays. exs = [self.embed_x(x) for x in xs] eys = [self.embed_y(y) for y in ys_in] hx, cx, _ = self.encoder(None, None, exs) _, _, os = self.decoder(hx, cx, eys) batch = len(xs) concat_os = F.concat(os, axis=0) concat_ys_out = F.concat(ys_out, axis=0) loss = F.sum(F.softmax_cross_entropy(self.W(concat_os), concat_ys_out, reduce='no')) / batch chainer.report({'loss': loss}, self) n_words = concat_ys_out.shape[0] perp = self.xp.exp(loss.array * batch / n_words) chainer.report({'perp': perp}, self) return loss def translate(self,xs,max_length=30): with chainer.no_backprop_mode(),chainer.using_config("train",False): xs=xs[::-1] # reverse list #exs = [self.embed_x(x) for x in xs] exs = self.embed_x(xs) hx, cx, _ = self.encoder(None, None, [exs]) result=[] word_id=EOS for i in range(max_length): os,cs=self._translate_word(word_id,hx,cx) word_id=int(F.argmax(os).data) if word_id==EOS:break result.append(word_id) return result def _translate_word(self,word_id,hx,cx): y=np.array([word_id],dtype=np.int32) ey=self.embed_y(y) _, cs, os = self.decoder(hx, cx, [ey]) fos=F.softmax(self.W(os[0])) return fos,cs class Data(chainer.dataset.DatasetMixin): def __init__(self): mecab = MeCab.Tagger("-Owakati") self.vocab={"<eos>":0,"<unk>":1} def to_dataset(source,target,train=True): swords = to_number(mecab.parse(source).strip().split(" "),train) twords = to_number(mecab.parse(target).strip().split(" "),train) return (np.array(swords).astype(np.int32),np.array(twords).astype(np.int32)) def to_number(words,train): ds=[] for w in words: if w not in self.vocab: if train: self.vocab[w]=len(self.vocab) else: w="<unk>" ds.append(self.vocab[w]) return ds self.train_data=[] self.test_data=[] for source,target in train_pair: self.train_data.append(to_dataset(source,target)) for source,target in test_pair: self.test_data.append(to_dataset(source,target,False)) self.vocab_inv={} for w in self.vocab.keys(): self.vocab_inv[self.vocab[w]]=w def convert(batch, device): def to_device_batch(batch): return [chainer.dataset.to_device(device, x) for x in batch] res= {'xs': to_device_batch([x for x, _ in batch]), 'ys': to_device_batch([y for _, y in batch])} return res seed = 12345 random.seed(seed) np.random.seed(seed) data = Data() batchsize=128 #train_iter = chainer.iterators.MultithreadIterator(train,batchsize,n_threads=4) #test_iter=chainer.iterators.MultithreadIterator(test,len(test),repeat=False,shuffle=False,n_threads=4) train_iter = chainer.iterators.SerialIterator(data.train_data,batchsize) test_iter=chainer.iterators.SerialIterator(data.test_data,len(data.test_data)) n_vocab=len(data.vocab) n_out=len(data.vocab) n_hidden=300 print("n_vocab:",n_vocab) optimizer=chainer.optimizers.Adam() mlp=EncoderDecoder(n_vocab,n_out,n_hidden) optimizer.setup(mlp) updater=chainer.training.StandardUpdater(train_iter,optimizer,converter=convert,device=-1) #train epochs=1000 trainer=chainer.training.Trainer(updater,(epochs,"epoch"),out="dialog_result") trainer.extend(chainer.training.extensions.LogReport()) trainer.extend(chainer.training.extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy'])) trainer.run() mlp.to_cpu() chainer.serializers.save_npz("dialog.model",mlp) # https://nojima.hatenablog.com/entry/2017/10/17/034840 mlp=EncoderDecoder(n_vocab,n_out,n_hidden) chainer.serializers.load_npz("dialog.model",mlp,path="") for source,target in data.test_data: predict=mlp.translate(np.array(source)) print("source:",[data.vocab_inv[w] for w in source]) print("predict:",[data.vocab_inv[w] for w in predict]) print("target:",[data.vocab_inv[w] for w in target]) |
結果
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | n_vocab: 71 epoch main/loss main/accuracy 1 34.5502 2 31.7847 3 29.1603 4 26.6713 5 24.3079 6 22.0612 7 19.9224 8 17.8789 9 15.919 10 14.0387 11 12.2457 12 10.5555 13 8.98597 14 7.55335 15 6.26964 16 5.14105 17 4.16823 18 3.34746 19 2.67084 20 2.12592 21 1.69575 22 1.36043 23 1.10025 24 0.898079 25 0.740202 26 0.616168 27 0.518174 28 0.44032 29 0.378046 30 0.327818 31 0.286931 32 0.253339 33 0.225496 34 0.202227 35 0.18263 36 0.166004 37 0.151796 38 0.139575 39 0.128993 40 0.119775 41 0.1117 42 0.104586 43 0.09829 44 0.0926925 45 0.0876961 46 0.0832203 47 0.0791977 48 0.0755707 49 0.0722914 50 0.0693159 51 0.0666087 52 0.0641389 53 0.0618778 54 0.0598021 55 0.0578909 56 0.0561267 57 0.0544931 58 0.0529775 59 0.051567 60 0.0502504 61 0.0490188 62 0.0478641 63 0.0467791 64 0.0457571 65 0.0447925 66 0.0438801 67 0.0430147 68 0.0421937 69 0.0414122 70 0.0406681 71 0.039957 72 0.0392778 73 0.0386279 74 0.0380038 75 0.037405 76 0.0368287 77 0.0362744 78 0.0357389 79 0.0352229 80 0.0347236 81 0.0342409 82 0.0337726 83 0.0333189 84 0.0328779 85 0.0324506 86 0.0320346 87 0.0316298 88 0.0312351 89 0.0308511 90 0.0304759 91 0.0301098 92 0.0297521 93 0.0294023 94 0.0290616 source: ['初め', 'まして', '。'] predict: ['初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。'] target: ['初め', 'まして', '。', 'よろしくお願いします', '。'] source: ['どこ', 'から', '来', 'た', 'ん', 'です', 'か', '?'] predict: ['日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来'] target: ['<unk>', 'から', '来', 'まし', 'た', '。'] source: ['<unk>', 'の', 'どこ', 'に', '住ん', 'でる', 'ん', 'です', 'か', '?'] predict: ['東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん'] target: ['<unk>', 'に', '住ん', 'で', 'い', 'ます', '。'] source: ['おはよう。'] predict: ['おはよう', 'ござい', 'ます', '。'] target: ['おはよう。'] source: ['いつも', '何', '時', 'に', '起き', 'ます', 'か', '?'] predict: ['6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に'] target: ['<unk>', '時', 'に', '起き', 'ます', '。'] source: ['<unk>', 'は', '何', 'を', '食べ', 'ます', 'か', '?'] predict: ['たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く'] target: ['たいてい', 'トースト', 'と', '卵', 'を', '食べ', 'ます', '。'] source: ['<unk>', 'は', '毎日', '食べ', 'ます', 'か', '?'] predict: ['たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食'] target: ['たまに', '朝食', 'を', '抜く', 'こと', 'が', 'あり', 'ます', '。'] source: ['<unk>', 'を', 'たくさん', '取っ', 'て', 'い', 'ます', 'か', '?'] predict: ['毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に'] target: ['毎日', '<unk>', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。'] source: ['<unk>', '<unk>', 'は', '何', 'を', 'し', 'て', 'い', 'ます', 'か', '?'] predict: ['友達', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と'] target: ['友達', 'と', '会っ', 'て', 'いる', 'こと', 'が', '多い', 'です', '。'] |