I rewrote the code using these as references.
For an explanation of CRFs themselves, see the linked pages.
Specification
Extract noun chunks from a sequence of words.
Flag chunks that consist of several nouns, such as compound nouns in English.
In the example below, O marks a token outside any noun chunk, B marks the first word of a chunk, and I marks the second and later words of a compound noun. Each sentence is wrapped in <s>/</s> markers, which are tagged O.
<s>  the  wall  street  journal  reported  today  that  apple  corporation  made  money  </s>
O    B    I     I       I        O         O      O     B      I            O     O      O
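To make the scheme concrete, here is a minimal sketch (a hypothetical helper, not part of the code below) that recovers the noun chunks from such a word/tag pair:

def extract_chunks(words, tags):
    """Collect maximal B/I runs as noun chunks."""
    chunks, current = [], []
    for w, t in zip(words, tags):
        if t == "B":                 # a new chunk starts here
            if current:
                chunks.append(" ".join(current))
            current = [w]
        elif t == "I" and current:   # continuation of the open chunk
            current.append(w)
        else:                        # O closes any open chunk
            if current:
                chunks.append(" ".join(current))
            current = []
    if current:
        chunks.append(" ".join(current))
    return chunks

# extract_chunks(words, tags) on the sentence above
# -> ['the wall street journal', 'apple corporation']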
These tags are learned with a single Linear layer and decoded with a CRF.
Code
import numpy as np
import argparse
import os
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, optimizers, reporter
from chainer.training import extensions
import random
from itertools import chain
# Model: word embedding -> tanh -> per-label scores, tied together by a CRF
class IconDetector(chainer.Chain):
    def __init__(self, n_vocab, window, n_label, n_unit):
        super().__init__()
        with self.init_scope():
            self.embed = L.EmbedID(n_vocab, n_unit)   # word id -> vector
            self.lin = L.Linear(n_unit, n_label)      # vector -> label scores
            self.crf = L.CRF1d(n_label=n_label)       # label transitions
        self.window = window
    def forward(self, xs):
        # xs: (batchsize, window) array of word ids
        l = xs.shape[1]
        ys = []
        # score one word position at a time
        for i in range(l):
            x = self.embed(xs[:, i])
            h = F.tanh(x)
            y = self.lin(h)
            ys.append(y)
        return ys  # list of length `window`; each element is (batchsize, n_label)
    def __call__(self, xs, ts):
        """Loss function: negative log-likelihood under the CRF."""
        ys = self.forward(xs)
        ts = [ts[:, i] for i in range(ts.shape[1])]  # transpose to a length-`window` list
        loss = self.crf(ys, ts)
        reporter.report({'loss': loss}, self)
        return loss
    def predict(self, xs):
        ys = self.forward(xs)
        _, path = self.crf.argmax(ys)  # Viterbi decoding
        return path
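# Note on shapes: L.CRF1d takes sequence-major input. Both the scores from
# forward() and the gold labels are length-`window` lists whose elements are
# batched arrays, (batchsize, n_label) float scores and (batchsize,) int32
# labels, which is why __call__ transposes `ts` column by column above.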
class WindowIterator(chainer.dataset.Iterator):
    """Yields (word-id window, label window) minibatches from a flat id stream."""
    def __init__(self, text, label, window, batch_size, shuffle=True, repeat=True):
        self.text = np.asarray(text, dtype=np.int32)
        self.label = np.asarray(label, dtype=np.int32)
        self.window = window
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle = shuffle
        # `order` holds every valid window start position
        if self._shuffle:
            self.order = np.random.permutation(len(text) - window).astype(np.int32)
        else:
            self.order = np.arange(len(text) - window, dtype=np.int32)
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False
    def __next__(self):
        if not self._repeat and self.epoch > 0:
            raise StopIteration
        i = self.current_position
        i_end = i + self.batch_size
        position = self.order[i:i_end]
        # each row of `pos` indexes one window: start position + 0..window-1
        offset = np.arange(self.window)
        pos = position[:, None] + offset[None, :]
        context = self.text.take(pos)   # (batchsize, window) word ids
        doc = self.label.take(pos)      # (batchsize, window) label ids
        if i_end >= len(self.order):
            if self._shuffle:           # reshuffle only when shuffling is on
                np.random.shuffle(self.order)
            self.epoch += 1
            self.is_new_epoch = True
            self.current_position = 0
        else:
            self.is_new_epoch = False
            self.current_position = i_end
        return context, doc
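    # Example (hypothetical numbers): for a 13-token stream with window=5 and
    # batchsize=2, `order` holds the start positions 0..7 and each __next__
    # call returns a pair of (2, 5) arrays of word ids and label ids.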
@property
def epoch_detail(self):
return self.epoch + float(self.current_position) / len(self.order)
# Converter: move a (context, doc) minibatch onto the target device
@chainer.dataset.converter()
def convert(batch, device):
    context, doc = batch
    xp = device.xp
    doc = xp.asarray(doc)
    context = xp.asarray(context)
    return context, doc
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--device', '-d', type=str, default='-1',
help='Device specifier. Either ChainerX device '
'specifier or an integer. If non-negative integer, '
'CuPy arrays with specified device id are used. If '
'negative integer, NumPy arrays are used')
parser.add_argument('--window', '-w', default=5, type=int,
help='window size')
parser.add_argument('--batchsize', '-b', type=int, default=2,
help='learning minibatch size')
parser.add_argument('--epoch', '-e', default=100, type=int,
help='number of epochs to learn')
parser.add_argument('--out', default=os.environ["WORK"]+"/summarization/sm_icgw/result",
help='Directory to output the result')
args = parser.parse_args()
random.seed(12345)
np.random.seed(12345)
device = chainer.get_device(args.device)
    # Each sentence is wrapped in <s>/</s> markers, both tagged O
    training_data = [(
        "<s> the wall street journal reported today that apple corporation made money </s>".split(),
        "O B I I I O O O B I O O O".split()
    ), (
        "<s> georgia tech is a university in georgia </s>".split(),
        "O B I O O O O B O".split()
    )]
    validation_data = [(
        "<s> georgia tech reported today </s>".split(),
        "O B I O O O".split()
    )]
    # special tokens: <unk> stands in for words unseen at training time;
    # <pad> is reserved (unused here)
    word2id = {"<s>": 0, "</s>": 1, "<unk>": 2, "<pad>": 3}
    label2id = {"B": 0, "I": 1, "O": 2}
    def get_dataset(data, is_train=True):
        texts = []
        labels = []
        for word, attrib in data:
            for w, a in zip(word, attrib):
                if w not in word2id:
                    if is_train:
                        word2id[w] = len(word2id)  # grow the vocabulary
                    else:
                        w = "<unk>"                # unseen word at eval/test time
                texts.append(word2id[w])
                labels.append(label2id[a])
        return texts, labels
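    # Note: get_dataset flattens every sentence into one long id stream, and
    # WindowIterator cuts fixed-size windows out of that stream, so a window
    # may straddle a sentence boundary; the <s>/</s> markers keep those
    # boundaries visible to the model.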
    train_text, train_label = get_dataset(training_data)
    valid_text, valid_label = get_dataset(validation_data, False)
    n_vocab = len(word2id)
    n_label = len(label2id)
    n_unit = n_vocab // 2
    model = IconDetector(n_vocab=n_vocab, window=args.window, n_label=n_label, n_unit=n_unit)
    model.to_device(device)
    optimizer = optimizers.Adam()
    optimizer.setup(model)
    train_iter = WindowIterator(train_text, train_label, args.window, args.batchsize)
    valid_iter = WindowIterator(valid_text, valid_label, args.window, args.batchsize,
                                repeat=False, shuffle=False)
    updater = training.StandardUpdater(train_iter, optimizer, converter=convert, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    trainer.extend(extensions.LogReport(), trigger=(10, 'epoch'))
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss']))
    # trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.Evaluator(valid_iter, model, converter=convert, device=device),
                   trigger=(10, 'epoch'))
trainer.run()
    # testing
    testing_data = [(
        "<s> the street journal is a university </s>".split(),
        "O B I I O O O O".split()
    )]
    test_text, test_label = get_dataset(testing_data, False)
    with chainer.using_config('train', False), \
         chainer.using_config('enable_backprop', False):
        ys = model.predict(np.array([test_text], dtype=np.int32))
        ys = list(chain.from_iterable(ys))  # flatten the per-position arrays
        print(ys)
        print(test_label)
if __name__ == '__main__':
main()
Results
epoch       main/loss   validation/main/loss
10          6.24186     7.77101
20          5.43089     6.29938
30          3.16409     5.09233
40          2.68146     4.02053
50          1.47248     3.07853
60          1.63073     2.28975
70          1.13568     1.67798
80          0.842947    1.22796
90          0.364242    0.912943
100         0.658557    0.692736
[2, 0, 1, 1, 2, 2, 2, 2]
[2, 0, 1, 1, 2, 2, 2, 2]
The predicted label ids (first list) match the gold labels (second list) exactly, so the model seems to have learned properly.
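Since predict returns label ids, mapping them back through label2id makes the output easier to read; a minimal sketch, reusing ys and label2id from the script above:

id2label = {v: k for k, v in label2id.items()}  # {0: 'B', 1: 'I', 2: 'O'}
print([id2label[int(y)] for y in ys])           # -> ['O', 'B', 'I', 'I', 'O', 'O', 'O', 'O']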