読者です 読者をやめる 読者になる 読者になる

chainerで点が円に内包されているか判定してみた ~~入門~~

 前回に引き続きchainer入門です。クラス分類やってみたいなと思ってやってみました。
 f:id:gasin:20170219233844p:plain
 円が目標の境界を示していて、点の色が分類結果です。
 ほんの少しずれていますが、ほぼ完全に分類できています。

 特に工夫したところなどもないので、コードを見てもらえればわかると思います。
 四層のニューラルネットです。

import numpy as np
import matplotlib.pyplot as plt
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L

class MyChain(Chain):
    def __init__(self):
        super(MyChain, self).__init__(
            l1 = L.Linear(2, 20),
            l2 = L.Linear(20, 50),
            l3 = L.Linear(50, 2),
        )
    
    def __call__(self, x, y):
        return F.softmax_cross_entropy(self.fwd(x), y)
    
    def fwd(self,x):
        h1 = F.sigmoid(self.l1(x))
        h2 = F.sigmoid(self.l2(h1))
        h3 = self.l3(h2)
        return h3

if __name__ == "__main__":
    train_size = 1000
    test_size = 300
    
    train_x = (np.random.rand(train_size,2)-0.5).astype(np.float32)
    train_y = np.zeros((train_size,), dtype=np.int32)
    test_x = (np.random.rand(test_size,2)-0.5).astype(np.float32)
    
    for (i, train) in enumerate(train_x):
        if np.sqrt(train[0]*train[0]+train[1]*train[1]) < 0.3:
            train_y[i] = 1
            
    n_epoch = 1000
    n_batch = 100
    
    model = MyChain()
#    serializers.load_hdf5('MyChain.model', model)
    optimizer = optimizers.Adam()
    optimizer.setup(model)
    
    for epoch in range(n_epoch):
        print 'epoch : ', epoch
        sffindx = np.random.permutation(train_size)
        for i in range(n_batch):
            x = Variable(train_x[sffindx[i:(i+n_batch) if (i+n_batch) < train_size else train_size]])
            y = Variable(train_y[sffindx[i:(i+n_batch) if (i+n_batch) < train_size else train_size]])
            model.zerograds()
            loss = model(x, y)
            loss.backward()
            optimizer.update()
        print 'loss : ', loss.data
        
    xt = Variable(test_x, volatile='on')
    yy = model.fwd(xt)
    ans = yy.data
    ans_row, ans_col = ans.shape
        
    for i in range(ans_row):
        cls = np.argmax(ans[i, :])
#        print np.max(ans[i, :]), np.sum(ans[i, :])
        if cls == 0:
            plt.plot(test_x[i][0], test_x[i][1], "bo")
        else:
            plt.plot(test_x[i][0], test_x[i][1], "ro")
        
    circle = plt.Circle((0.0,0.0), 0.3)
    fig = plt.gcf()
    ax = fig.gca()
    ax.add_artist(circle)
    
    filename = "output.png"
    plt.xlabel("x")
    plt.ylabel("y")
    plt.savefig(filename)
    plt.show()
    serializers.save_hdf5('MyChain.model', model)