このあたりを参考に書き直してみました。
CRFなどの解説はリンク先を参照のこと
仕様
単語列から名詞の塊を抜き出す。
英語で複合名詞など複数の名詞で構成される名詞の塊にフラグをつける。
下記の例では、Oが名詞以外、Bが名詞の開始位置、Iが複合名詞の2個目以降を示している。
1 2 | <s> the wall street journal reported today that apple corporation made money </s> O B I I I O O O B I O O O |
これを、Linearの1層で学習し、CRFで出力する
コード
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | import numpy as np import argparse import os import chainer import chainer.functions as F import chainer.links as L from chainer import training, optimizers, initializers,reporter from chainer.training import extensions import random from itertools import chain # my model class IconDetector(chainer.Chain): def __init__(self, n_vocab, window, n_label, n_unit): super().__init__() with self.init_scope(): self.embed=L.EmbedID(n_vocab, n_unit) self.lin=L.Linear(n_unit , n_label) self.crf=L.CRF1d(n_label=n_label) self.window=window def forward(self, xs): l = xs.shape[1] ys = [] # 1 wordづつ for i in range(l): x = self.embed(xs[:,i]) h = F.tanh(x) y = self.lin(h) ys.append(y) return ys #[window, batchsize, n_label] def __call__(self, xs, ts): """error function""" ys = self.forward(xs) ts = [ts[:, i] for i in range(ts.data.shape[1])] # [window,batchsize] loss = self.crf(ys, ts) reporter.report({'loss': loss}, self) return loss def predict(self, xs): ts = self.forward(xs) _, ys = self.crf.argmax(ts) return ys class WindowIterator(chainer.dataset.Iterator): def __init__(self, text, label, window, batch_size, shuffle= True,repeat=True): self.text = np.asarray(text, dtype=np.int32) self.label = np.asarray(label, dtype=np.int32) self.window = window self.batch_size = batch_size self._repeat = repeat self._shuffle=shuffle if self._shuffle: self.order = np.random.permutation( len(text) - window ).astype(np.int32) else: self.order=np.array(list(range(len(text) - window ))) self.current_position = 0 self.epoch = 0 self.is_new_epoch = False def __next__(self): if not self._repeat and self.epoch > 0: raise StopIteration i = self.current_position i_end = i + self.batch_size position = self.order[i: i_end] offset = np.concatenate([np.arange(0, self.window )]) pos = position[:, None] + offset[None, :] context = self.text.take(pos) doc = self.label.take(pos) if i_end >= len(self.order): np.random.shuffle(self.order) self.epoch += 1 self.is_new_epoch = True self.current_position = 0 else: self.is_new_epoch = False self.current_position = i_end return context, doc @property def epoch_detail(self): return self.epoch + float(self.current_position) / len(self.order) @chainer.dataset.converter() def convert(batch, device): context, doc = batch xp = device.xp doc = xp.asarray(doc) context = xp.asarray(context) return context, doc def main(): parser = argparse.ArgumentParser() parser.add_argument('--device', '-d', type=str, default='-1', help='Device specifier. Either ChainerX device ' 'specifier or an integer. If non-negative integer, ' 'CuPy arrays with specified device id are used. If ' 'negative integer, NumPy arrays are used') parser.add_argument('--window', '-w', default=5, type=int, help='window size') parser.add_argument('--batchsize', '-b', type=int, default=2, help='learning minibatch size') parser.add_argument('--epoch', '-e', default=100, type=int, help='number of epochs to learn') parser.add_argument('--out', default=os.environ["WORK"]+"/summarization/sm_icgw/result", help='Directory to output the result') args = parser.parse_args() random.seed(12345) np.random.seed(12345) device = chainer.get_device(args.device) training_data = [( "<s> the wall street journal reported today that apple corporation made money </s>".split(), "O B I I I O O O B I O O O".split() ), ( "<s> georgia tech is a university in georgia </s>".split(), "O B I O O O O B O".split() )] validation_data = [('<s> georgia tech reported today </s>'.split(),"O B I O O O".split())] word2id={"<s>":0,"</s>":1,"<unk>":2,"<pad>":3} label2id={"B":0,"I":1,"O":2} def get_dataset(data,is_train=True): texts=[] labels=[] for word,attrib in data: for w,a in zip(word,attrib): if w not in word2id: if is_train: word2id[w]=len(word2id) else: w="<unk>" texts.append(word2id[w]) labels.append(label2id[a]) return texts,labels train_text,train_label=get_dataset(training_data) valid_text,valid_label=get_dataset(validation_data,False) n_vocab=len(word2id) n_label=len(label2id) n_unit=n_vocab//2 model=IconDetector(n_vocab=n_vocab, window=args.window,n_label=n_label,n_unit=n_unit) model.to_device(device) optimizer = optimizers.Adam() optimizer.setup(model) train_iter = WindowIterator(train_text, train_label, args.window, args.batchsize) valid_iter = WindowIterator(valid_text, valid_label, args.window,args.batchsize, repeat=False, shuffle=False) updater = training.StandardUpdater(train_iter, optimizer, converter=convert, device=device) trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) trainer.extend(extensions.LogReport(),trigger=(10, 'epoch')) trainer.extend(extensions.PrintReport(['epoch', 'main/loss','validation/main/loss'])) #trainer.extend(extensions.ProgressBar()) trainer.extend(extensions.Evaluator(valid_iter, model,converter=convert, device=device),trigger=(10, 'epoch')) trainer.run() # testing testing_data = [('<s> the street journal is a university </s>'.split(),"O B I I O O O O".split())] test_text,test_label=get_dataset(testing_data,False) with chainer.using_config('train', False), \ chainer.using_config('enable_backprop', False): ys=model.predict(np.array([test_text],dtype=np.int32)) ys=list(chain.from_iterable(ys)) print(ys) print(test_label) if __name__ == '__main__': main() |
実行結果
1 2 3 4 5 6 7 8 9 10 11 12 13 | epoch main/loss validation/main/loss 10 6.24186 7.77101 20 5.43089 6.29938 30 3.16409 5.09233 40 2.68146 4.02053 50 1.47248 3.07853 60 1.63073 2.28975 70 1.13568 1.67798 80 0.842947 1.22796 90 0.364242 0.912943 100 0.658557 0.692736 [2, 0, 1, 1, 2, 2, 2, 2] [2, 0, 1, 1, 2, 2, 2, 2] |
ちゃんと学習できているようです