博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
莫烦课程Batch Normalization 批标准化
阅读量:5955 次
发布时间:2019-06-19

本文共 1146 字,大约阅读时间需要 3 分钟。

for i in range(N_HIDDEN):               # build hidden layers and BN layers            input_size = 1 if i == 0 else 10            fc = nn.Linear(input_size, 10)            setattr(self, 'fc%i' % i, fc)       # IMPORTANT set layer to the Module            self._set_init(fc)                  # parameters initialization            self.fcs.append(fc)            if self.do_bn:                bn = nn.BatchNorm1d(10, momentum=0.5)                setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Moduleself.bns.append(bn)

 上面的代码对每个隐层进行批标准化,setattr(self, 'fc%i' % i, fc)作用相当于self.fci=fc

每次生成的结果append到bns的最后面,结果的size 10×10,取出这些数据是非常方便

def forward(self, x):        pre_activation = [x]        if self.do_bn: x = self.bn_input(x)     # input batch normalization        layer_input = [x]        for i in range(N_HIDDEN):            x = self.fcs[i](x)            pre_activation.append(x)            if self.do_bn: x = self.bns[i](x)   # batch normalization            x = ACTIVATION(x)            layer_input.append(x)        out = self.predict(x)return out, layer_input, pre_activation

全部的

 

转载于:https://www.cnblogs.com/lindaxin/p/8034069.html

你可能感兴趣的文章
做程序开发的你如果经常用Redis,这些问题肯定会遇到
查看>>
CAS-认证流程
查看>>
006android初级篇之jni数据类型映射
查看>>
Java 集合框架查阅技巧
查看>>
apache配置虚拟主机
查看>>
CollectionView水平和竖直瀑布流的实现
查看>>
前端知识复习一(css)
查看>>
spark集群启动步骤及web ui查看
查看>>
利用WCF改进文件流传输的三种方式
查看>>
Spring学习总结(2)——Spring的常用注解
查看>>
关于IT行业人员吃的都是青春饭?[透彻讲解]
查看>>
钱到用时方恨少(随记)
查看>>
mybatis主键返回的实现
查看>>
org.openqa.selenium.StaleElementReferenceException
查看>>
数论之 莫比乌斯函数
查看>>
linux下查找某个文件位置的方法
查看>>
python之MySQL学习——数据操作
查看>>
Harmonic Number (II)
查看>>
长连接、短连接、长轮询和WebSocket
查看>>
day30 模拟ssh远程执行命令
查看>>