详解文本分类之DeepCNN的理论与实践
导读
最近在梳理文本分类的各个神经网络算法,特地一个来总结下。下面目录中多通道卷积已经讲过了,下面是链接,没看的可以瞅瞅。我会一个一个的讲解各个算法的理论与实践。目录暂定为:
多通道卷积神经网络(multi_channel_CNN)
深度卷积神经网络(deep_CNN)
基于字符的卷积神经网络(Char_CNN)
循环与卷积神经网络并用网络(LSTM_CNN)
树状LSTM神经网络(Tree-LSTM)
Transformer(目前常用于NMT)
etc..
之后的以后再补充。今天我们该将第二个,深度卷积神经网络(DeepCNN)。
DeepCNN
DeepCNN即是深度卷积神经网络,就是有大于1层的卷积网络,也可以说是多层卷积网络(Multi_Layer_CNN,咳咳,我就是这么命名滴!)我们来直接上图,看看具体长得啥样子:
我大概描述下这个过程,比如sent_len=10,embed_dim=100,也就是输入的矩阵为(10*100),假设kernel num=n,用了上下padding,kernel size=(3*100),那么卷积之后输出的矩阵为(n*10),接着再将该矩阵放入下个卷积中,放之前我们先对这个矩阵做个转置,你肯定要问为什么?俺来告诉你我自己的认识,有两点:
硬性要求:这个矩阵第一个维度为10是句子长度产生的,所以是变量,我们习惯将该维度的大小控制为定量,比如第一个输入的值就是(sent_len,embed_dim),embed_dim就为定量,不变。所以转置即可。
理论要求:(n*10)中的n处于的维度的数据表示的是上个数据kernel对这个数据的10个数据第一次计算,第二次计算... 第10次计算,也就可以表示为通过kernel对上个数据的每个词和它的上下文进行了新的特征提取。n则表示用n个kernel对上个句子提取了n次。则最终的矩阵为(n*10),我们要转成和输入的格式一样,将第二维度依然放上一个词的表示。所以转置即可。
n 可以设置100,200等。
然后对最终的结果进行pooling,cat,然后进过线性层映射到分类上,进过softmax上进行预测输出即可。
上述仅仅说的是两层CNN的搭建,当然你可以搭建很多层啦。
实践
下面看下具体的pytotch代码如何实现
类Multi_Layer_CNN的初始化
def __init__(self, opts, vocab, label_vocab): super(Multi_Layer_CNN, self).__init__() random.seed(opts.seed) torch.manual_seed(opts.seed) torch.cuda.manual_seed(opts.seed) self.embed_dim = opts.embed_size self.word_num = vocab.m_size self.pre_embed_path = opts.pre_embed_path self.string2id = vocab.string2id self.embed_uniform_init = opts.embed_uniform_init self.stride = opts.stride self.kernel_size = opts.kernel_size self.kernel_num = opts.kernel_num self.label_num = label_vocab.m_size self.embed_dropout = opts.embed_dropout self.fc_dropout = opts.fc_dropout self.embeddings = nn.Embedding(self.word_num, self.embed_dim) if opts.pre_embed_path != '': embedding = Embedding.load_predtrained_emb_zero(self.pre_embed_path, self.string2id) self.embeddings.weight.data.copy_(embedding) else: nn.init.uniform_(self.embeddings.weight.data, -self.embed_uniform_init, self.embed_uniform_init) # 2 convs self.convs1 = nn.ModuleList( [nn.Conv2d(1, self.embed_dim, (K, self.embed_dim), stride=self.stride, padding=(K // 2, 0)) for K in self.kernel_size]) self.convs2 = nn.ModuleList( [nn.Conv2d(1, self.kernel_num, (K, self.embed_dim), stride=self.stride, padding=(K // 2, 0)) for K in self.kernel_size]) in_fea = len(self.kernel_size)*self.kernel_num self.linear1 = nn.Linear(in_fea, in_fea // 2) self.linear2 = nn.Linear(in_fea // 2, self.label_num) self.embed_dropout = nn.Dropout(self.embed_dropout) self.fc_dropout = nn.Dropout(self.fc_dropout)
数据流动 def forward(self, input): out = self.embeddings(input) out = self.embed_dropout(out) # torch.Size([64, 39, 100]) l = [] out = out.unsqueeze(1) # torch.Size([64, 1, 39, 100]) for conv in self.convs1: l.append(torch.transpose(F.relu(conv(out)).squeeze(3), 1, 2)) # torch.Size([64, 39, 100]) out = l l = [] for conv, last_out in zip(self.convs2, out): l.append(F.relu(conv(last_out.unsqueeze(1))).squeeze(3)) # torch.Size([64, 100, 39]) out = l l = [] for i in out: l.append(F.max_pool1d(i, kernel_size=i.size(2)).squeeze(2)) # torch.Size([64, 100]) out = torch.cat(l, 1) # torch.Size([64, 300]) out = self.fc_dropout(out) out = self.linear1(out) out = self.linear2(F.relu(out)) return out
数据对比
可以看出多层(深层)CNN还是在有提升的。
原文发布时间为:2018-11-8
原文作者:zenRRan
本文来自云栖社区合作伙伴“深度学习自然语言处理”,了解相关信息可以关注“深度学习自然语言处理”。
低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
AI要是抢饭碗,第一个轮到谁?
前几天,人工智能被国家点名了! 来自国脉智慧城市网 文中提出要加强人工智能在“教育、医疗卫生、体育、住房、焦勇、助残养老、家政服务等领域的深度应用”。这意味着什么? 有一些行业的工作职位(尤其是中央点名的这些行业)会因AI的普及而消灭,与此同时新的就业机会会应运而生。 哪些工作会被消灭? 局部来看,流水线性质的职业被替代是一种趋势。 就像蒸汽机的发明替代了马夫,电脑的发明替代了打字员。互联网的发明(几乎)代替了邮递员。正如李开复老师在TED演讲上提到的趋势:未来15年内,流水线上的工作会被代替,即使一些律师医生也不例外。 观点节选之李开复老师在Ted上演讲:人工智能如何拯救人类 AI会带来更多就业? 根据历史的经验和经济学家的研究,技术进步总体上来讲并不会增加失业。恰恰相反,在技术进步快的时候,反而是增加就业的。 蒸汽机带来了火车和工厂,使工业效率成倍增加,带来无数无产阶级;电脑从大型变为PC,使工作效率不断增加,带来成千上万“上班族”;互联网从龟速到极速,使人们的联系范围不断扩大,带来的更多的“自由职业者”。AI的发展,必然带来更多“操作AI”的职业。 如何搭上这波AI就业的车? ●...
- 下一篇
汉语言处理包 HanLP v1.3.5,新功能、优化与维护
HanLP v1.3.5 更新内容: 大幅优化CRF分词和二阶HMM分词,重构CharacterBasedGenerativeModelSegment 自定义词典支持热更新:#563 ,ngram模型支持热加载:#580 新增一个提高用户词典优先级的开关:#633 支持98年人民日报的复合词语料格式,如"[中央/n 人民/n 广播/vn 电台/n]nt" 开放TextRank关键词提取中的最大迭代次数参数:#577 为Term添加equal方法 TextRankKeyword 提取窗口相近词的强化 文本摘要方法支持自定义句子分隔符 提高AC自动机健壮性,添加hasKeyword接口 修复BinTrie.remove不存在的key时导致的问题:#540 解决mini模型下同时打开所有命名实体识别和数词识别时触发的问题:#542 CharTable.txt 添加上下标字符的对应关系 将“t”等不可打印的字符视作分隔符:#584 中文数词与阿拉伯数词切分开 修正全角年份识别中字符串长度错误,修正数字识别工具的错误,增加测试代码。支持读取包含BOM的文本文件。 校对CoreNatureDict...
相关文章
文章评论
共有0条评论来说两句吧...