こちら
で紹介しているSeq2Seqを動く形で作成しなおしてみました。
train_pair = [
["初めまして。", "初めまして。よろしくお願いします。"],
["どこから来たんですか?", "日本から来ました。"],
["日本のどこに住んでるんですか?", "東京に住んでいます。"],
["仕事は何してますか?", "私は会社員です。"],
["お会いできて嬉しかったです。", "私もです!"],
["おはよう。", "おはようございます。"],
["いつも何時に起きますか?", "6時に起きます。"],
["朝食は何を食べますか?", "たいていトーストと卵を食べます。"],
["朝食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
["野菜をたくさん取っていますか?", "毎日野菜を取るようにしています。"],
["週末は何をしていますか?", "友達と会っていることが多いです。"],
["どこに行くのが好き?", "私たちは渋谷に行くのが好きです。"]
]
test_pair = [
["初めまして。", "初めまして。よろしくお願いします。"],
["どこから来たんですか?", "米国から来ました。"],
["米国のどこに住んでるんですか?", "ニューヨークに住んでいます。"],
["おはよう。", "おはよう。"],
["いつも何時に起きますか?", "7時に起きます。"],
["夕食は何を食べますか?", "たいていトーストと卵を食べます。"],
["夕食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
["肉をたくさん取っていますか?", "毎日インクを取るようにしています。"],
["週頭は何をしていますか?", "友達と会っていることが多いです。"],
]
# https://nojima.hatenablog.com/entry/2017/10/10/023147
import nltk
import MeCab
import chainer
from chainer.datasets import split_dataset_random , cifar
import chainer.links as L
import chainer.functions as F
import random
import numpy as np
import codecs
SIZE=10000
EOS=1
UNK=0
class EncoderDecoder(chainer.Chain):
def __init__(self, n_vocab, n_out, n_hidden):
super(EncoderDecoder,self).__init__()
with self.init_scope():
self.embed_x = L.EmbedID(n_vocab, n_hidden)
self.embed_y = L.EmbedID(n_out,n_hidden)
self.encoder = L.NStepLSTM(
n_layers=1,
in_size=n_hidden,
out_size=n_hidden,
dropout=0.1)
self.decoder = L.NStepLSTM(
n_layers=1,
in_size=n_hidden,
out_size=n_hidden,
dropout=0.1)
self.W = L.Linear(n_hidden, n_out)
def __call__(self, xs , ys ):
xs = [x[::-1] for x in xs]
eos = self.xp.array([EOS], dtype=np.int32)
ys_in = [F.concat((eos, y), axis=0) for y in ys]
ys_out = [F.concat((y, eos), axis=0) for y in ys]
# Both xs and ys_in are lists of arrays.
exs = [self.embed_x(x) for x in xs]
eys = [self.embed_y(y) for y in ys_in]
hx, cx, _ = self.encoder(None, None, exs)
_, _, os = self.decoder(hx, cx, eys)
batch = len(xs)
concat_os = F.concat(os, axis=0)
concat_ys_out = F.concat(ys_out, axis=0)
loss = F.sum(F.softmax_cross_entropy(self.W(concat_os), concat_ys_out, reduce='no')) / batch
chainer.report({'loss': loss}, self)
n_words = concat_ys_out.shape[0]
perp = self.xp.exp(loss.array * batch / n_words)
chainer.report({'perp': perp}, self)
return loss
def translate(self,xs,max_length=30):
with chainer.no_backprop_mode(),chainer.using_config("train",False):
xs=xs[::-1] # reverse list
#exs = [self.embed_x(x) for x in xs]
exs = self.embed_x(xs)
hx, cx, _ = self.encoder(None, None, [exs])
result=[]
word_id=EOS
for i in range(max_length):
os,cs=self._translate_word(word_id,hx,cx)
word_id=int(F.argmax(os).data)
if word_id==EOS:break
result.append(word_id)
return result
def _translate_word(self,word_id,hx,cx):
y=np.array([word_id],dtype=np.int32)
ey=self.embed_y(y)
_, cs, os = self.decoder(hx, cx, [ey])
fos=F.softmax(self.W(os[0]))
return fos,cs
class Data(chainer.dataset.DatasetMixin):
def __init__(self):
mecab = MeCab.Tagger("-Owakati")
self.vocab={"":0,"":1}
def to_dataset(source,target,train=True):
swords = to_number(mecab.parse(source).strip().split(" "),train)
twords = to_number(mecab.parse(target).strip().split(" "),train)
return (np.array(swords).astype(np.int32),np.array(twords).astype(np.int32))
def to_number(words,train):
ds=[]
for w in words:
if w not in self.vocab:
if train:
self.vocab[w]=len(self.vocab)
else:
w=""
ds.append(self.vocab[w])
return ds
self.train_data=[]
self.test_data=[]
for source,target in train_pair:
self.train_data.append(to_dataset(source,target))
for source,target in test_pair:
self.test_data.append(to_dataset(source,target,False))
self.vocab_inv={}
for w in self.vocab.keys():
self.vocab_inv[self.vocab[w]]=w
def convert(batch, device):
def to_device_batch(batch):
return [chainer.dataset.to_device(device, x) for x in batch]
res= {'xs': to_device_batch([x for x, _ in batch]),
'ys': to_device_batch([y for _, y in batch])}
return res
seed = 12345
random.seed(seed)
np.random.seed(seed)
data = Data()
batchsize=128
#train_iter = chainer.iterators.MultithreadIterator(train,batchsize,n_threads=4)
#test_iter=chainer.iterators.MultithreadIterator(test,len(test),repeat=False,shuffle=False,n_threads=4)
train_iter = chainer.iterators.SerialIterator(data.train_data,batchsize)
test_iter=chainer.iterators.SerialIterator(data.test_data,len(data.test_data))
n_vocab=len(data.vocab)
n_out=len(data.vocab)
n_hidden=300
print("n_vocab:",n_vocab)
optimizer=chainer.optimizers.Adam()
mlp=EncoderDecoder(n_vocab,n_out,n_hidden)
optimizer.setup(mlp)
updater=chainer.training.StandardUpdater(train_iter,optimizer,converter=convert,device=-1)
#train
epochs=1000
trainer=chainer.training.Trainer(updater,(epochs,"epoch"),out="dialog_result")
trainer.extend(chainer.training.extensions.LogReport())
trainer.extend(chainer.training.extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy']))
trainer.run()
mlp.to_cpu()
chainer.serializers.save_npz("dialog.model",mlp)
# https://nojima.hatenablog.com/entry/2017/10/17/034840
mlp=EncoderDecoder(n_vocab,n_out,n_hidden)
chainer.serializers.load_npz("dialog.model",mlp,path="")
for source,target in data.test_data:
predict=mlp.translate(np.array(source))
print("source:",[data.vocab_inv[w] for w in source])
print("predict:",[data.vocab_inv[w] for w in predict])
print("target:",[data.vocab_inv[w] for w in target])
結果
n_vocab: 71 epoch main/loss main/accuracy 1 34.5502 2 31.7847 3 29.1603 4 26.6713 5 24.3079 6 22.0612 7 19.9224 8 17.8789 9 15.919 10 14.0387 11 12.2457 12 10.5555 13 8.98597 14 7.55335 15 6.26964 16 5.14105 17 4.16823 18 3.34746 19 2.67084 20 2.12592 21 1.69575 22 1.36043 23 1.10025 24 0.898079 25 0.740202 26 0.616168 27 0.518174 28 0.44032 29 0.378046 30 0.327818 31 0.286931 32 0.253339 33 0.225496 34 0.202227 35 0.18263 36 0.166004 37 0.151796 38 0.139575 39 0.128993 40 0.119775 41 0.1117 42 0.104586 43 0.09829 44 0.0926925 45 0.0876961 46 0.0832203 47 0.0791977 48 0.0755707 49 0.0722914 50 0.0693159 51 0.0666087 52 0.0641389 53 0.0618778 54 0.0598021 55 0.0578909 56 0.0561267 57 0.0544931 58 0.0529775 59 0.051567 60 0.0502504 61 0.0490188 62 0.0478641 63 0.0467791 64 0.0457571 65 0.0447925 66 0.0438801 67 0.0430147 68 0.0421937 69 0.0414122 70 0.0406681 71 0.039957 72 0.0392778 73 0.0386279 74 0.0380038 75 0.037405 76 0.0368287 77 0.0362744 78 0.0357389 79 0.0352229 80 0.0347236 81 0.0342409 82 0.0337726 83 0.0333189 84 0.0328779 85 0.0324506 86 0.0320346 87 0.0316298 88 0.0312351 89 0.0308511 90 0.0304759 91 0.0301098 92 0.0297521 93 0.0294023 94 0.0290616 source: ['初め', 'まして', '。'] predict: ['初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。', '初め', 'まして', '。'] target: ['初め', 'まして', '。', 'よろしくお願いします', '。'] source: ['どこ', 'から', '来', 'た', 'ん', 'です', 'か', '?'] predict: ['日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来'] target: ['トレイニングデータは収束しているようですが,結果はまだまだダメそうですね。', 'から', '来', 'まし', 'た', '。'] source: [' ', 'の', 'どこ', 'に', '住ん', 'でる', 'ん', 'です', 'か', '?'] predict: ['東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん', '東京', 'に', '住ん'] target: [' ', 'に', '住ん', 'で', 'い', 'ます', '。'] source: ['おはよう。'] predict: ['おはよう', 'ござい', 'ます', '。'] target: ['おはよう。'] source: ['いつも', '何', '時', 'に', '起き', 'ます', 'か', '?'] predict: ['6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に'] target: [' ', '時', 'に', '起き', 'ます', '。'] source: [' ', 'は', '何', 'を', '食べ', 'ます', 'か', '?'] predict: ['たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く', 'たいてい', 'トースト', 'と', '卵', 'を', '抜く'] target: ['たいてい', 'トースト', 'と', '卵', 'を', '食べ', 'ます', '。'] source: [' ', 'は', '毎日', '食べ', 'ます', 'か', '?'] predict: ['たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食', 'を', '抜く', 'たまに', '朝食'] target: ['たまに', '朝食', 'を', '抜く', 'こと', 'が', 'あり', 'ます', '。'] source: [' ', 'を', 'たくさん', '取っ', 'て', 'い', 'ます', 'か', '?'] predict: ['毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に', '毎日', '野菜', 'を', '取る', 'よう', 'に'] target: ['毎日', ' ', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。'] source: [' ', ' ', 'は', '何', 'を', 'し', 'て', 'い', 'ます', 'か', '?'] predict: ['友達', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と', '会っ', 'て', 'いる', 'と'] target: ['友達', 'と', '会っ', 'て', 'いる', 'こと', 'が', '多い', 'です', '。']