こちら
で紹介しているSeq2Seqを動く形で作成しなおしてみました。
train_pair = [
["初めまして。", "初めまして。よろしくお願いします。"],
["どこから来たんですか?", "日本から来ました。"],
["日本のどこに住んでるんですか?", "東京に住んでいます。"],
["仕事は何してますか?", "私は会社員です。"],
["お会いできて嬉しかったです。", "私もです!"],
["おはよう。", "おはようございます。"],
["いつも何時に起きますか?", "6時に起きます。"],
["朝食は何を食べますか?", "たいていトーストと卵を食べます。"],
["朝食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
["野菜をたくさん取っていますか?", "毎日野菜を取るようにしています。"],
["週末は何をしていますか?", "友達と会っていることが多いです。"],
["どこに行くのが好き?", "私たちは渋谷に行くのが好きです。"]
]
test_pair = [
["初めまして。", "初めまして。よろしくお願いします。"],
["どこから来たんですか?", "米国から来ました。"],
["米国のどこに住んでるんですか?", "ニューヨークに住んでいます。"],
["おはよう。", "おはよう。"],
["いつも何時に起きますか?", "7時に起きます。"],
["夕食は何を食べますか?", "たいていトーストと卵を食べます。"],
["夕食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
["肉をたくさん取っていますか?", "毎日インクを取るようにしています。"],
["週頭は何をしていますか?", "友達と会っていることが多いです。"],
]
# https://nojima.hatenablog.com/entry/2017/10/10/023147
import nltk
import MeCab
import chainer
from chainer.datasets import split_dataset_random , cifar
import chainer.links as L
import chainer.functions as F
import random
import numpy as np
import codecs
SIZE=10000
EOS=1
UNK=0
class EncoderDecoder(chainer.Chain):
def __init__(self, n_vocab, n_out, n_hidden):
super(EncoderDecoder,self).__init__()
with self.init_scope():
self.embed_x = L.EmbedID(n_vocab, n_hidden)
self.embed_y = L.EmbedID(n_out,n_hidden)
self.encoder = L.NStepLSTM(
n_layers=1,
in_size=n_hidden,
out_size=n_hidden,
dropout=0.1)
self.decoder = L.NStepLSTM(
n_layers=1,
in_size=n_hidden,
out_size=n_hidden,
dropout=0.1)
self.W = L.Linear(n_hidden, n_out)
def __call__(self, xs , ys ):
xs = [x[::-1] for x in xs]
eos = self.xp.array([EOS], dtype=np.int32)
ys_in = [F.concat((eos, y), axis=0) for y in ys]
ys_out = [F.concat((y, eos), axis=0) for y in ys]
# Both xs and ys_in are lists of arrays.
exs = [self.embed_x(x) for x in xs]
eys = [self.embed_y(y) for y in ys_in]
hx, cx, _ = self.encoder(None, None, exs)
_, _, os = self.decoder(hx, cx, eys)
batch = len(xs)
concat_os = F.concat(os, axis=0)
concat_ys_out = F.concat(ys_out, axis=0)
loss = F.sum(F.softmax_cross_entropy(self.W(concat_os), concat_ys_out, reduce='no')) / batch
chainer.report({'loss': loss}, self)
n_words = concat_ys_out.shape[0]
perp = self.xp.exp(loss.array * batch / n_words)
chainer.report({'perp': perp}, self)
return loss
def translate(self,xs,max_length=30):
with chainer.no_backprop_mode(),chainer.using_config("train",False):
xs=xs[::-1] # reverse list
#exs = [self.embed_x(x) for x in xs]
exs = self.embed_x(xs)
hx, cx, _ = self.encoder(None, None, [exs])
result=[]
word_id=EOS
for i in range(max_length):
os,cs=self._translate_word(word_id,hx,cx)
word_id=int(F.argmax(os).data)
if word_id==EOS:break
result.append(word_id)
return result
def _translate_word(self,word_id,hx,cx):
y=np.array([word_id],dtype=np.int32)
ey=self.embed_y(y)
_, cs, os = self.decoder(hx, cx, [ey])
fos=F.softmax(self.W(os[0]))
return fos,cs
class Data(chainer.dataset.DatasetMixin):
def __init__(self):
mecab = MeCab.Tagger("-Owakati")
self.vocab={"":0,"":1}
def to_dataset(source,target,train=True):
swords = to_number(mecab.parse(source).strip().split(" "),train)
twords = to_number(mecab.parse(target).strip().split(" "),train)
return (np.array(swords).astype(np.int32),np.array(twords).astype(np.int32))
def to_number(words,train):
ds=[]
for w in words:
if w not in self.vocab:
if train:
self.vocab[w]=len(self.vocab)
else:
w=""
ds.append(self.vocab[w])
return ds
self.train_data=[]
self.test_data=[]
for source,target in train_pair:
self.train_data.append(to_dataset(source,target))
for source,target in test_pair:
self.test_data.append(to_dataset(source,target,False))
self.vocab_inv={}
for w in self.vocab.keys():
self.vocab_inv[self.vocab[w]]=w
def convert(batch, device):
def to_device_batch(batch):
return [chainer.dataset.to_device(device, x) for x in batch]
res= {'xs': to_device_batch([x for x, _ in batch]),
'ys': to_device_batch([y for _, y in batch])}
return res
seed = 12345
random.seed(seed)
np.random.seed(seed)
data = Data()
batchsize=128
#train_iter = chainer.iterators.MultithreadIterator(train,batchsize,n_threads=4)
#test_iter=chainer.iterators.MultithreadIterator(test,len(test),repeat=False,shuffle=False,n_threads=4)
train_iter = chainer.iterators.SerialIterator(data.train_data,batchsize)
test_iter=chainer.iterators.SerialIterator(data.test_data,len(data.test_data))
n_vocab=len(data.vocab)
n_out=len(data.vocab)
n_hidden=300
print("n_vocab:",n_vocab)
optimizer=chainer.optimizers.Adam()
mlp=EncoderDecoder(n_vocab,n_out,n_hidden)
optimizer.setup(mlp)
updater=chainer.training.StandardUpdater(train_iter,optimizer,converter=convert,device=-1)
#train
epochs=1000
trainer=chainer.training.Trainer(updater,(epochs,"epoch"),out="dialog_result")
trainer.extend(chainer.training.extensions.LogReport())
trainer.extend(chainer.training.extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy']))
trainer.run()
mlp.to_cpu()
chainer.serializers.save_npz("dialog.model",mlp)
# https://nojima.hatenablog.com/entry/2017/10/17/034840
mlp=EncoderDecoder(n_vocab,n_out,n_hidden)
chainer.serializers.load_npz("dialog.model",mlp,path="")
for source,target in data.test_data:
predict=mlp.translate(np.array(source))
print("source:",[data.vocab_inv[w] for w in source])
print("predict:",[data.vocab_inv[w] for w in predict])
print("target:",[data.vocab_inv[w] for w in target])
結果
n_vocab: 71
epoch main/loss main/accuracy
1 34.5502
2 31.7847
3 29.1603
4 26.6713
5 24.3079
6 22.0612
7 19.9224
8 17.8789
9 15.919
10 14.0387
11 12.2457
12 10.5555
13 8.98597
14 7.55335
15 6.26964
16 5.14105
17 4.16823
18 3.34746
19 2.67084
20 2.12592
21 1.69575
22 1.36043
23 1.10025
24 0.898079
25 0.740202
26 0.616168
27 0.518174
28 0.44032
29 0.378046
30 0.327818
31 0.286931
32 0.253339
33 0.225496
34 0.202227
35 0.18263
36 0.166004
37 0.151796
38 0.139575
39 0.128993
40 0.119775
41 0.1117
42 0.104586
43 0.09829
44 0.0926925
45 0.0876961
46 0.0832203
47 0.0791977
48 0.0755707
49 0.0722914
50 0.0693159
51 0.0666087
52 0.0641389
53 0.0618778
54 0.0598021
55 0.0578909
56 0.0561267
57 0.0544931
58 0.0529775
59 0.051567
60 0.0502504
61 0.0490188
62 0.0478641
63 0.0467791
64 0.0457571
65 0.0447925
66 0.0438801
67 0.0430147
68 0.0421937
69 0.0414122
70 0.0406681
71 0.039957
72 0.0392778
73 0.0386279
74 0.0380038
75 0.037405
76 0.0368287
77 0.0362744
78 0.0357389
79 0.0352229
80 0.0347236
81 0.0342409
82 0.0337726
83 0.0333189
84 0.0328779
85 0.0324506
86 0.0320346
87 0.0316298
88 0.0312351
89 0.0308511
90 0.0304759
91 0.0301098
92 0.0297521
93 0.0294023
94 0.0290616
source: ['初め', 'まして', '。']
predict: ['初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。']
target: ['初め', 'まして', '。', 'よろしくお願いします', '。']
source: ['どこ', 'から', '来', 'た', 'ん', 'です', 'か', '?']
predict: ['日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来']
target: ['', 'から', '来', 'まし', 'た', '。']
source: ['', 'の', 'どこ', 'に', '住ん', 'でる', 'ん', 'です', 'か', '?']
predict: ['東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん']
target: ['', 'に', '住ん', 'で', 'い', 'ます', '。']
source: ['おはよう。']
predict: ['おはよう', 'ござい', 'ます', '。']
target: ['おはよう。']
source: ['いつも', '何', '時', 'に', '起き', 'ます', 'か', '?']
predict: ['6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に']
target: ['', '時', 'に', '起き', 'ます', '。']
source: ['', 'は', '何', 'を', '食べ', 'ます', 'か', '?']
predict: ['たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く']
target: ['たいてい', 'トースト', 'と', '卵', 'を', '食べ', 'ます', '。']
source: ['', 'は', '毎日', '食べ', 'ます', 'か', '?']
predict: ['たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食']
target: ['たまに', '朝食', 'を', '抜く', 'こと', 'が', 'あり', 'ます', '。']
source: ['', 'を', 'たくさん', '取っ', 'て', 'い', 'ます', 'か', '?']
predict: ['毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に']
target: ['毎日', '', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。']
source: ['', '', 'は', '何', 'を', 'し', 'て', 'い', 'ます', 'か', '?']
predict: ['友達', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と']
target: ['友達', 'と', '会っ', 'て', 'いる', 'こと', 'が', '多い', 'です', '。']
トレイニングデータは収束しているようですが,結果はまだまだダメそうですね。