chainerで点が円に内包されているか判定してみた ~~入門~~
前回に引き続きchainer入門です。クラス分類やってみたいなと思ってやってみました。
円が目標の境界を示していて、点の色が分類結果です。
ほんの少しずれていますが、ほぼ完全に分類できています。
特に工夫したところなどもないので、コードを見てもらえればわかると思います。
四層のニューラルネットです。
import numpy as np import matplotlib.pyplot as plt import chainer from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils from chainer import Link, Chain, ChainList import chainer.functions as F import chainer.links as L class MyChain(Chain): def __init__(self): super(MyChain, self).__init__( l1 = L.Linear(2, 20), l2 = L.Linear(20, 50), l3 = L.Linear(50, 2), ) def __call__(self, x, y): return F.softmax_cross_entropy(self.fwd(x), y) def fwd(self,x): h1 = F.sigmoid(self.l1(x)) h2 = F.sigmoid(self.l2(h1)) h3 = self.l3(h2) return h3 if __name__ == "__main__": train_size = 1000 test_size = 300 train_x = (np.random.rand(train_size,2)-0.5).astype(np.float32) train_y = np.zeros((train_size,), dtype=np.int32) test_x = (np.random.rand(test_size,2)-0.5).astype(np.float32) for (i, train) in enumerate(train_x): if np.sqrt(train[0]*train[0]+train[1]*train[1]) < 0.3: train_y[i] = 1 n_epoch = 1000 n_batch = 100 model = MyChain() # serializers.load_hdf5('MyChain.model', model) optimizer = optimizers.Adam() optimizer.setup(model) for epoch in range(n_epoch): print 'epoch : ', epoch sffindx = np.random.permutation(train_size) for i in range(n_batch): x = Variable(train_x[sffindx[i:(i+n_batch) if (i+n_batch) < train_size else train_size]]) y = Variable(train_y[sffindx[i:(i+n_batch) if (i+n_batch) < train_size else train_size]]) model.zerograds() loss = model(x, y) loss.backward() optimizer.update() print 'loss : ', loss.data xt = Variable(test_x, volatile='on') yy = model.fwd(xt) ans = yy.data ans_row, ans_col = ans.shape for i in range(ans_row): cls = np.argmax(ans[i, :]) # print np.max(ans[i, :]), np.sum(ans[i, :]) if cls == 0: plt.plot(test_x[i][0], test_x[i][1], "bo") else: plt.plot(test_x[i][0], test_x[i][1], "ro") circle = plt.Circle((0.0,0.0), 0.3) fig = plt.gcf() ax = fig.gca() ax.add_artist(circle) filename = "output.png" plt.xlabel("x") plt.ylabel("y") plt.savefig(filename) plt.show() serializers.save_hdf5('MyChain.model', model)