?

基于卷積神經網絡和自注意力機制的文本分類模型

2020-06-03 07:57汪嘉偉楊煦晨琚生根謝正文
關鍵詞:長距離準確率注意力

汪嘉偉, 楊煦晨, 琚生根, 袁 宵, 謝正文

(四川大學計算機學院, 成都 610065)

1 引 言

文本分類為自由文本文檔分配預定義的類別,是自然語言處理領域的基礎性任務.文本分類的應用包括情感分析[1]、問題分類[2]、主題分類[3-5]等.卷積神經網絡(Convolutional Neural Network,CNN)[6-8]廣泛應用于文本分類.單詞級別的淺層CNN模型[6]使用預訓練的詞向量[9]作為輸入,利用多種具有不同過濾器的CNN抽取文本序列的局部特征,在文本分類任務上取得了良好的表現.由于模型的深度較淺(只有一層CNN),單詞級別的淺層CNN模型無法捕捉長距離依賴[10].文獻[11]詳細研究了CNN模型的深度對分類效果的影響,發現對于單詞級別的CNN模型,加深模型的層數并不能提高模型的準確率,反而導致模型準確率的下降.為了捕捉長距離依賴,本文引入自注意力機制.首先,文本序列中的每個單詞通過CNN得到一個上下文表示,自注意力機制通過計算所有單詞的上下文表示兩兩之間的相似度捕捉長距離依賴;然后,利用最大池化得到文本序列的最終表示;最后,將該表示送入全連接層得到分類結果.與單詞級別的淺層CNN模型比較,本文的模型在AGNews、DBPedia、Yelp Review Polarity、Yelp Review Full、Yahoo! Answers 5個公開的數據集上準確率得到了一致的提升.

2 相關工作

在文本分類上,有著大量的研究.傳統的方法使用線性模型[4]或支持向量機[12-13]根據手工構造的文本特征對文本進行分類.這些特征包括詞袋特征、n-gram特征、TF-IDF特征等.

近年來,隨著深度學習的發展,神經網絡模型廣泛應用于文本分類[1,6-8,14-16].文本分類任務的神經網絡模型主要分為三大類:基于RNN的模型、基于CNN的模型和基于注意力機制的模型.

RNN適用于處理序列輸入,因此,許多RNN的變種被應用于文本分類.文獻[14]利用LSTM建模序列,文獻[1]利用LSTM和門控RNN建模句子間的關系.文獻[15]利用層級GRU對文檔進行建模,并利用注意力機制捕獲文檔中重要的單詞信息和句子信息.文獻[16]將殘差連接[17]引入RNN,使模型能夠處理更長的序列.

CNN在計算機視覺領域獲得了巨大的成功[17-18],文獻[19]首次將CNN應用于自然語言處理任務.文獻[6]使用預訓練的詞向量[9]作為輸入,利用一層CNN捕捉文本序列的局部特征和位置信息.文獻[7]首次探索了字符級別的深層CNN(6層)分類模型.文獻[8]構建了一個字符級別的極深的CNN(29層)分類模型.

文獻[20]首次僅利用注意力機制解決自然語言處理任務,沒有使用任何RNN和CNN結構.文獻[21]將注意力機制應用于文本分類任務,與文獻[20]相同,沒有使用任何RNN和CNN結構.

基于RNN的分類模型受制于RNN的串行結構,無法在序列上并行計算.字符級別的深層CNN模型由于模型深度的急劇增加,導致模型的計算復雜度隨之上升,嚴重影響了模型在實踐中的應用.僅僅基于注意力機制的模型無法捕捉文本序列的局部特征.單詞級別的淺層CNN模型無法捕捉長距離依賴.本文結合CNN和自注意力機制,提出一種新的單詞級別的文本分類模型Word-CNN-Att. Word-CNN-Att使用CNN捕捉文檔的局部特征,利用自注意力機制捕捉長距離依賴.

3 模 型

模型的整體架構如圖1所示.模型由卷積層、自注意力層、池化層和全連接層組成.卷積層用于捕捉文本序列的局部特征和位置信息,自注意力層用于捕捉長距離依賴,池化層用于獲得文本序列的最終表示,全連接層用于最后的分類.

3.1 卷積層

卷積層用于提取輸入序列的局部特征和位置信息.xi∈Rd是一個d維的向量,表示輸入序列中的第i個單詞,一個長度為n的序列表示為:

x1:n=x1⊕x2⊕…⊕xn

(1)

其中,⊕是一個連接操作符;xi:i+j表示單詞xi,xi+1,…,xi+j的連接;過濾器w∈Rk×d對具有k個單詞的窗口進行卷積操作,產生新的特征.例如,特征

ci=f(w·xi-(k-1)/2,i+(k-1)/2+b)

(2)

其中,b∈R是偏置;f是非線性函數ReLU.對于超過序列邊界的索引,本文采用零填充.這個過濾器應用到每個可能的窗口,產生一個特征圖.

(3)

卷積層共有m個核寬為k的過濾器,對每個過濾器重復上述過程,并將得到的特征圖連接起來,得到:

Z=(z1,z2,…,zn)

(4)

Z∈Rn×m.如圖1所示,本文采用多個核寬分別為3、4、5的過濾器.

3.2 自注意力層

自注意力機制的核心在于點乘注意力[20],點乘注意力的計算過程如圖2所示,定義如下.

(5)

圖1 模型架構.使用3種不同的過濾器,分別具有核寬:3,4,5,每種過濾器有兩個Fig.1 Architecture of Model. 3 convolutional layers with respective kernel window sizes 3,4,5 are used, and each of which has 2 filters

令Z=(z1,z2,…,zn)為卷積層的輸出,即自注意力層的輸入,zi∈Rm.在自注意力機制中,Q,K,V都是同一向量的線性變換.因此,本文定義自注意力如下.

Self-Att(Z)=

Attention(ZWQ,ZWK,ZWV)=

(6)

其中,WQ,WK,WV∈Rm×m,WQ,WK,WV都是模型的參數,在模型訓練中學習得到.

自注意力機制通過計算整個序列中所有令牌兩兩之間的相似度捕捉任意距離的長距離依賴[20].與RNN不同,由于RNN下一時刻的輸入依賴于上一時刻的隱藏層狀態,所以捕捉距離為n的長距離依賴,RNN的時間復雜度為O(n).而自注意力機制可以并行計算任意兩兩令牌之間的相似度,捕捉距離為n的長距離依賴的時間復雜度為O(1).因此,自注意力機制有著非常好的并行性.

自注意力層的整體結構如圖3所示,與文獻[18]相同,本文引入殘差連接[17]和層歸一化[22].因此,自注意力層的輸出為

SelfAtt-Out=layernorm(Self-Att(Z)+Z)

(7)

圖2 點乘注意力

圖3 自注意力層結構Fig.3 Architecture of self-attention layer

3.3 池化層與全連接層

對于自注意力層的輸出,應用最大池化,每個特征圖得到一個最大值,將所有特征圖的最大值連接起來,得到輸入序列的最終表示.

g=(e1,e2,…,em)

(8)

最后,一個線性層將g映射成文本類別數目的維度.

y=Wyg+by

(9)

4 實 驗

4.1 任務和數據集

本文實驗采用5個大規模的文本分類數據集,這些數據集在文獻[7]中提出,包括4種分類任務:新聞分類、本體分類、情感分析和主題分類.數據集的具體情況如表1所示.表1中,“#Train”代表訓練集的樣例數目;“#Test”代表測試集的樣例數目;“#Classes”代表數據集的種類個數;“#Average Length”代表數據集中樣例的平均單詞數目.

表1 數據集的分布

4.2 實驗細節

本文使用NLTK對語料進行分詞,僅使用在訓練集中至少出現3次的單詞構建詞表.詞表中未出現的單詞使用一個特殊的令牌UNK代替.

本文使用斯坦福大學公開發行的Glove 300維詞向量[9]作為預訓練的詞向量.對于未出現在預訓練的詞向量中的單詞,本文使用從均勻分布(-0.1, 0.1)中采樣的300維向量作為其詞向量.

本文使用初始學習率為0.001的Adam優化算法[23].batch size設為64.對于每個數據集,實驗使用訓練集的10%作為驗證集.本文使用核寬為3、4、5的過濾器各100個.在模型的輸入層和線性層使用dropout[24],dropout的丟棄率為0.5.

4.3 實驗結果與分析

本文使用準確率作為評價指標,準確率越大模型效果越好.實驗結果如表2所示.

表2 模型準確率

表2的第2行至第4行展示了傳統的方法的準確率,bag of words[7]模型基于訓練集中頻率最高的50 000個單詞構建,ngrams[7]模型基于訓練集中頻率最高的500 000個n-grams構建,ngrams TFIDF[7]模型與ngrams[7]模型相同,但使用TFIDF作為特征.從表2可知,傳統的方法在AGNews、DBPedia、Yelp Review Polarity3個相對小的數據集上表現較好,在Yelp Review Full、Yahoo! Answers兩個相對大的數據集上表現較差.FastText[25]為文本分類模型提供了一個有力的基線.

char-CNN[7]、char-CRNN[26]、char-VDCNN[8]都是字符級別的CNN模型,將字符作為基本輸入單位.char-CNN使用了一個深層的CNN(6層).與char-CNN相比,本文的模型Word-CNN-Att在AGNews、DBPedia、Yelp Review Polarity、Yelp Review Full、Yahoo! Answers 5個數據集準確率分別提高了6.5%、0.6%、1.4%、3.0%、2.9%.char-CRNN模型利用CNN和RNN聯合學習文本特征,與char-CRNN模型相比,Word-CNN-Att在5個數據集準確率分別提高了2.3%、0.3%、1.6%、3.2%、2.4%.char-VDCNN構建了一個極深的CNN(29層),與char-VDCNN相比,Word-CNN-Att在5個數據集準確率分別提高了2.4%、0.2%、0.4%、0.3%、0.7%.可以看到,盡管char-VDCNN遠比Word-CNN-Att深,Word-CNN-Att在各個數據集上的準確率仍然均超過了char-VDCNN模型.由上述分析可知,字符級別的模型,無論是純粹的CNN模型或結合CNN和RNN的模型,盡管模型遠比Word-CNN-Att深,但表現均不如Word-CNN-Att. Word-CNN-Att是單詞級別的模型,可以有效地利用單詞的語義信息,而字符級別的模型無法利用單詞的語義信息.

Discriminative LSTM[14]是一個單詞級別的模型,利用LSTM作為特征提取器,將輸入序列中所有單詞的隱藏層狀態之和作為文本序列的最終表示.與Discriminative LSTM相比,Word-CNN-Att在5個數據集準確率分別提高了1.6%、0.2%、3.5%、5.4%、0.4%.與LSTM相比,CNN能夠有效地捕捉局部特征,但無法捕捉長距離依賴,而Word-CNN-Att利用自注意力機制捕捉長距離依賴.

Self-Attention[21]模型完全基于自注意力機制,沒有使用任何RNN和CNN結構,利用專門的位置向量編碼位置信息.與Self-Attention模型相比,Word-CNN-Att在5個數據集準確率分別提高了1.1%、0.2%、0.9%、1.0%、0.0%.Word-CNN-Att不僅使用自注意力機制捕捉長距離依賴,并且使用CNN學習文本序列的局部特征,而純粹的自注意力機制無法學習局部特征;與專門的位置向量相比,CNN能夠更有效地學習位置信息.

如表2所示,與單詞級別的一層CNN(Word-CNN)模型相比,Word-CNN-Att在5個數據集準確率分別提高了0.9%、0.2%、0.5%、2.1%、2.0%.實驗結果表明,自注意力機制有效地捕捉了長距離依賴,彌補了CNN無法捕捉長距離依賴的不足,提升了模型處理長文本分類任務的能力.表3展示了一個word-CNN-Att分類正確,而word-CNN分類錯誤的樣例.word-CNN或許抽取了一些關鍵的局部信息,比如:great、lovely,從而將該樣例錯誤分類為positive.而word-CNN-Att模型可以捕捉長距離依賴,因此可以捕捉到However之后的信息,所以將該樣例正確分類為negative.

表3 樣例分析

5 結 論

本文提出了一種結合CNN和自注意力機制的文本分類模型Word-CNN-Att.該模型利用CNN提取文本局部特征和位置信息,利用自注意力機制捕捉長距離依賴.在5個大型公開文本分類數據集上的實驗結果表明,Word-CNN-Att提升了單詞級別的淺層CNN模型的效果.在未來的研究中,計劃引入外部知識來進一步增強模型的文本分類能力.

猜你喜歡
長距離準確率注意力
讓注意力“飛”回來
乳腺超聲檢查診斷乳腺腫瘤的特異度及準確率分析
不同序列磁共振成像診斷脊柱損傷的臨床準確率比較探討
2015—2017 年寧夏各天氣預報參考產品質量檢驗分析
頸椎病患者使用X線平片和CT影像診斷的臨床準確率比照觀察
如何培養一年級學生的注意力
長距離PC Hi-Fi信號傳輸“神器” FIBBR Alpha
探討給排水長距離管道的頂管施工技術
A Beautiful Way Of Looking At Things
我國最長距離特高壓輸電工程開工
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合