博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
可以多分类的神经网络
阅读量:6528 次
发布时间:2019-06-24

本文共 3831 字,大约阅读时间需要 12 分钟。

 

'''11行神经网络① 层数可变,多类'''import numpy as npimport matplotlib.pyplot as pltclass BPNN(object):    def __init__(self, neurals, epsilon=0.01, lamda=0.01):        if not isinstance(neurals, list): raise '出错:参数 neurals 必须是 list 类型'        self.neurals = neurals   # 各层神经元数目,如:[3,10,8,2]        self.size = len(neurals)                self.epsilon = epsilon  # 学习速率        self.lamda = lamda      # 正则化强度        self.w = [np.random.randn(i,j) for i,j in zip(neurals[:-1], neurals[1:])] + [None]        self.b = [None] + [np.random.randn(1,j) for j in neurals[1:]]        self.l = [None] * self.size        self.l_delta = [None] * self.size                self.probs = None            # 前向传播    def forward(self, X):        self.l[0] = X        for i in range(1, self.size-1):            self.l[i] = np.tanh(np.dot(self.l[i-1], self.w[i-1]) + self.b[i]) # tanh 函数            self.l[-1] = np.exp(np.dot(self.l[-2], self.w[-2]) + self.b[-1])        self.probs = self.l[-1] / np.sum(self.l[-1], axis=1, keepdims=True)            # 后向传播    def backward(self, y):        self.l_delta[-1] = np.copy(self.probs)        self.l_delta[-1][range(self.n_samples), y] -= 1        for i in range(self.size-2, 0, -1):            self.l_delta[i] = np.dot(self.l_delta[i+1], self.w[i].T) * (1 - np.power(self.l[i], 2)) # tanh 函数的导数                # 更新权值、偏置    def update(self):        self.b[-1] -= self.epsilon * np.sum(self.l_delta[-1], axis=0, keepdims=True)        for i in range(self.size-2, -1, -1):            self.w[i] -= self.epsilon * (np.dot(self.l[i].T, self.l_delta[i+1]) + self.lamda * self.w[i])            if i == 0: break            self.b[i] -= self.epsilon * np.sum(self.l_delta[i], axis=0)        # 计算损失    def calculate_loss(self, y):        loss = np.sum(-np.log(self.probs[range(self.n_samples), y]))        loss += self.lamda/2 * np.sum([np.sum(np.square(wi)) for wi in self.w[:-1]]) # 可选        loss *= 1/self.n_samples  # 可选        return loss        # 拟合    def fit(self, X, y, n_iter=1000, print_loss=True):        self.n_samples = X.shape[0] # 样本大小(样本数目)                for i in range(n_iter):            self.forward(X)            self.backward(y)            self.update()                        if not print_loss: continue            if i%100 == 0: print(self.calculate_loss(y))        # 预测    def predict(self, x):        self.forward(x)        return np.argmax(self.probs, axis=1)        def plot_decision_boundary(clf, X, y):    # Set min and max values and give it some padding    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5    h = 0.01    # Generate a grid of points with distance h between them    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))    # Predict the function value for the whole gid    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])    Z = Z.reshape(xx.shape)    # Plot the contour and training examples    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)        def test1():    X = np.array([[0,0,1],[0,1,1],[1,0,1],[1,1,1]])    #y = np.array([0,1,1,0]) # 两类    y = np.array([0,1,2,3])  # 多类        # bpnn = BPNN([3, 10, 8, 1])    bpnn = BPNN([3, 10, 8, 4])    bpnn.fit(X, y, n_iter=1000)        print('训练结果:', bpnn.predict(X))        def test2():    from sklearn.datasets import make_moons    from sklearn.linear_model import LogisticRegressionCV        X, y = make_moons(200, noise=0.20)    plt.scatter(X[:,0], X[:,1], s=40, c=y, cmap=plt.cm.Spectral)    plt.show()    clf = LogisticRegressionCV()    clf.fit(X, y)    plot_decision_boundary(clf, X, y)    plt.show()    #nn = BPNN([2,5,4,2])    nn = BPNN([2,4,2])    nn.fit(X, y, n_iter=1000)    plot_decision_boundary(nn, X, y)    plt.show()            if __name__ == '__main__':    #test1()    test2()

 

转载地址:http://mxtbo.baihongyu.com/

你可能感兴趣的文章
优化LibreOffice如此简单
查看>>
【Oracle 数据迁移】环境oracle 11gR2,exp无法导出空表的表结构【转载】
查看>>
秒杀系统设计方案
查看>>
3D印花芭蕾舞鞋为舞者科学地保护双脚
查看>>
冲浪科技获Ventech China数百万美元天使轮融资,发力自动驾驶行业
查看>>
通过ActionTrail监控AccessKey的使用
查看>>
从 JavaScript 到 TypeScript
查看>>
一个mysql复制中断的案例
查看>>
【最佳实践】OSS开源工具ossutil-大文件断点续传
查看>>
Linux常用的服务器构建
查看>>
深入了解 Weex
查看>>
Android第三方开源FloatingActionButton(com.getbase.floatingactionbutton)【1】
查看>>
【75位联合作者Nature重磅】AI药神:机器学习模型有望提前五年预测白血病!
查看>>
精通SpringBoot——第二篇:视图解析器,静态资源和区域配置
查看>>
JavaScript基础(六)面向对象
查看>>
总结几点Quartz的经验
查看>>
物联网、自动化的冲击下未来20年职场六大趋势
查看>>
《Java核心技术 卷Ⅱ 高级特性(原书第10版)》一3.6.2 使用StAX解析器
查看>>
9月26日云栖精选夜读:阿里Java代码规约插件即将全球首发,邀您来发布仪式现场...
查看>>
北京市交管局联合高德地图发布北京中考出行提示
查看>>