chainerで実装してみた
こちらを参考に,実装してみた。
train_pair = [
["初めまして。", "初めまして。よろしくお願いします。"],
["どこから来たんですか?", "日本から来ました。"],
["日本のどこに住んでるんですか?", "東京に住んでいます。"],
["仕事は何してますか?", "私は会社員です。"],
["お会いできて嬉しかったです。", "私もです!"],
["おはよう。", "おはようございます。"],
["いつも何時に起きますか?", "6時に起きます。"],
["朝食は何を食べますか?", "たいていトーストと卵を食べます。"],
["朝食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
["野菜をたくさん取っていますか?", "毎日野菜を取るようにしています。"],
["週末は何をしていますか?", "友達と会っていることが多いです。"],
["どこに行くのが好き?", "私たちは渋谷に行くのが好きです。"]
]
test_pair = [
["初めまして。", "初めまして。よろしくお願いします。"],
["どこから来たんですか?", "米国から来ました。"],
["米国のどこに住んでるんですか?", "ニューヨークに住んでいます。"],
["おはよう。", "おはよう。"],
["いつも何時に起きますか?", "7時に起きます。"],
["夕食は何を食べますか?", "たいていトーストと卵を食べます。"],
["夕食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
["肉をたくさん取っていますか?", "毎日インクを取るようにしています。"],
["週頭は何をしていますか?", "友達と会っていることが多いです。"],
]
# https://nojima.hatenablog.com/entry/2017/10/10/023147
import MeCab
import chainer
import chainer.links as L
import chainer.functions as F
import random
import numpy as np
SIZE=10000
EOS=1
UNK=0
class EncoderDecoder(chainer.Chain):
def __init__(self, n_layer,n_vocab, n_out, n_hidden,dropout):
super(EncoderDecoder,self).__init__()
with self.init_scope():
self.embed_x = L.EmbedID(n_vocab, n_hidden)
self.embed_y = L.EmbedID(n_out,n_hidden)
self.encoder = L.NStepLSTM(
n_layers=n_layer,
in_size=n_hidden,
out_size=n_hidden,
dropout=dropout)
self.decoder = L.NStepLSTM(
n_layers=n_layer,
in_size=n_hidden,
out_size=n_hidden,
dropout=dropout)
self.W_C = L.Linear(2*n_hidden, n_hidden)
self.W_D = L.Linear(n_hidden, n_out)
def __call__(self, xs , ys ):
xs = [x[::-1] for x in xs]
eos = self.xp.array([EOS], dtype=np.int32)
ys_in = [F.concat((eos, y), axis=0) for y in ys]
ys_out = [F.concat((y, eos), axis=0) for y in ys]
# Both xs and ys_in are lists of arrays.
exs = [self.embed_x(x) for x in xs]
eys = [self.embed_y(y) for y in ys_in]
# hx:dimension x batchsize
# cx:dimension x batchsize
# yx:batchsize x timesize x dimension
hx, cx, yx = self.encoder(None, None, exs) # yxに全T方向ステップのyの出力,数はxsの長さと同じ,バッチごとにバラバラ
_, _, os = self.decoder(hx, cx, eys)
loss=0
for o,y,ey in zip(os,yx,ys_out): # バッチごとに処理
op=self._calculate_attention_layer_output(o,y)
loss+=F.softmax_cross_entropy(op,ey)
loss/=len(yx)
chainer.report({'loss': loss}, self)
return loss
def _calculate_attention_layer_output(self, embedded_output, attention):
inner_prod = F.matmul(embedded_output, attention, transb=True)
weights = F.softmax(inner_prod)
contexts = F.matmul(weights, attention)
concatenated = F.concat((contexts, embedded_output))
new_embedded_output = F.tanh(self.W_C(concatenated))
return self.W_D(new_embedded_output)
def translate(self,xs,max_length=30):
with chainer.no_backprop_mode(),chainer.using_config("train",False):
xs=xs[::-1] # reverse list
#exs = [self.embed_x(x) for x in xs]
exs = self.embed_x(xs)
hx, cx, yx = self.encoder(None, None, [exs])
predicts=[]
eos = self.xp.array([EOS], dtype=np.int32)
# EOSだけ入力,あとは予想した出力を入力にして繰り返す
for y in yx: # バッチ単位
predict=[]
ys_in=[eos]
for i in range(max_length):
eys = [self.embed_y(y) for y in ys_in]
_, _, os = self.decoder(hx, cx, eys)
op=self._calculate_attention_layer_output(os[0], y)
word_id=int(F.argmax(F.softmax(op)).data) # 単語IDに戻す
if word_id == EOS:break
predict.append(word_id)
ys_in=[self.xp.array([word_id], dtype=np.int32)]
predicts.append(np.array(predict))
return predict
class Data(chainer.dataset.DatasetMixin):
def __init__(self):
mecab = MeCab.Tagger("-Owakati")
self.vocab={"eos":0,"unk":1}
def to_dataset(source,target,train=True):
swords = to_number(mecab.parse(source).strip().split(" "),train)
twords = to_number(mecab.parse(target).strip().split(" "),train)
return (np.array(swords).astype(np.int32),np.array(twords).astype(np.int32))
def to_number(words,train):
ds=[]
for w in words:
if w not in self.vocab:
if train:
self.vocab[w]=len(self.vocab)
else:
w="unk"
ds.append(self.vocab[w])
return ds
self.train_data=[]
self.test_data=[]
for source,target in train_pair:
self.train_data.append(to_dataset(source,target))
for source,target in test_pair:
self.test_data.append(to_dataset(source,target,False))
self.vocab_inv={}
for w in self.vocab.keys():
self.vocab_inv[self.vocab[w]]=w
def convert(batch, device):
def to_device_batch(batch):
return [chainer.dataset.to_device(device, x) for x in batch]
res= {'xs': to_device_batch([x for x, _ in batch]),
'ys': to_device_batch([y for _, y in batch])}
return res
seed = 12345
random.seed(seed)
np.random.seed(seed)
data = Data()
batchsize=5
train_iter = chainer.iterators.SerialIterator(data.train_data,batchsize)
test_iter=chainer.iterators.SerialIterator(data.test_data,len(data.test_data))
n_vocab=len(data.vocab)
n_out=len(data.vocab)
n_hidden=300
n_layer=1
dropout=0.3
print("n_vocab:",n_vocab)
optimizer=chainer.optimizers.Adam()
mlp=EncoderDecoder(n_layer,n_vocab,n_out,n_hidden,dropout)
optimizer.setup(mlp)
updater=chainer.training.StandardUpdater(train_iter,optimizer,converter=convert,device=-1)
#train
epochs=20
trainer=chainer.training.Trainer(updater,(epochs,"epoch"),out="dialog_result")
trainer.extend(chainer.training.extensions.LogReport())
trainer.extend(chainer.training.extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy']))
trainer.run()
mlp.to_cpu()
chainer.serializers.save_npz("attention.model",mlp)
# https://nojima.hatenablog.com/entry/2017/10/17/034840
mlp=EncoderDecoder(n_layer,n_vocab,n_out,n_hidden,dropout)
chainer.serializers.load_npz("attention.model",mlp,path="")
for source,target in data.test_data:
predict=mlp.translate(np.array(source))
print("-----")
print("source:",[data.vocab_inv[w] for w in source])
print("predict:",[data.vocab_inv[w] for w in predict])
print("target:",[data.vocab_inv[w] for w in target])
実行結果
n_vocab: 71 epoch main/loss main/accuracy 1 4.00112 2 3.148 3 2.33681 4 1.62404 5 1.30338 6 0.93243 7 0.6091 8 0.3701 9 0.233166 10 0.202335 11 0.119416 12 0.0804442 13 0.0629114 14 0.0467078 15 0.032828 16 0.0285745 17 0.0225082 18 0.0183743 19 0.0152494 20 0.0140195 ----- source: ['初め', 'まして', '。'] predict: ['初め', 'まして', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします'] target: ['初め', 'まして', '。', 'よろしくお願いします', '。'] ----- source: ['どこ', 'から', '来', 'た', 'ん', 'です', 'か', '?'] predict: ['日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来'] target: ['unk', 'から', '来', 'まし', 'た', '。'] ----- source: ['unk', 'の', 'どこ', 'に', '住ん', 'でる', 'ん', 'です', 'か', '?'] predict: ['東京', 'に', '住ん', 'で', 'い', 'ます', '。'] target: ['unk', 'に', '住ん', 'で', 'い', 'ます', '。'] ----- source: ['おはよう。'] predict: ['おはよう', 'ござい', 'ます', '。'] target: ['おはよう。'] ----- source: ['いつも', '何', '時', 'に', '起き', 'ます', 'か', '?'] predict: ['6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に'] target: ['unk', '時', 'に', '起き', 'ます', '。'] ----- source: ['unk', 'は', '何', 'を', '食べ', 'ます', 'か', '?'] predict: ['たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と'] target: ['たいてい', 'トースト', 'と', '卵', 'を', '食べ', 'ます', '。'] ----- source: ['unk', 'は', '毎日', '食べ', 'ます', 'か', '?'] predict: ['たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く'] target: ['たまに', '朝食', 'を', '抜く', 'こと', 'が', 'あり', 'ます', '。'] ----- source: ['unk', 'を', 'たくさん', '取っ', 'て', 'い', 'ます', 'か', '?'] predict: ['毎日', '野菜', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。'] target: ['毎日', 'unk', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。'] ----- source: ['unk', 'unk', 'は', '何', 'を', 'し', 'て', 'い', 'ます', 'か', '?'] predict: ['友達', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と'] target: ['友達', 'と', '会っ', 'て', 'いる', 'こと', 'が', '多い', 'です', '。']
あまり,結果が良くないが,訓練データを増やし,Epochを増やすと良くなるかもしれない。