?

基于混合域注意力機制的服裝關鍵點定位及屬性預測算法

2022-09-19 01:29雷冬冬王俊英董方敏臧兆祥聶雄鋒
關鍵詞:關鍵點卷積注意力

雷冬冬,王俊英,董方敏,臧兆祥,聶雄鋒

(三峽大學 a.水電工程智能視覺監測湖北省重點實驗室;b.湖北省建筑質量檢測裝備工程技術研究中心,湖北 宜昌 443002)

近年來服裝視覺的應用日益廣泛,如模擬試衣間、同款搜圖、換裝游戲等,具有較大的潛在應用價值。實際應用中服裝視覺算法面臨各種挑戰,如由模特姿勢引起的服裝變形和服裝遮擋,服裝款式、材質和剪裁上的差異,以及同款服裝在“買家秀”和“賣家秀”中的差異等。神經網絡作為解決視覺分析領域問題的重要方法之一,得到了廣大研究人員的青睞。

神經網絡在誕生之初吸收了生物學的原理本質,并在后續發展中脫離了生物細節,使用更加講究效率的數理工科思維,從而取得成功。研究者們基于神經網絡所做的服裝視覺分析工作[1-7]也取得了顯著成效,主要體現在服裝關鍵點檢測、服裝檢索和服裝的屬性預測等方面?;谧藨B估計的方法[8-9]通過對服裝姿態進行估計消除了服裝姿態對服裝關鍵點檢測的影響?;诩s束的方法[3,10-11]在算法模型中加入語義約束,利用布局約束或空間關系等語義約束提高服裝關鍵點檢測的性能?;谧⒁饬C制的方法[12-15]識別圖像的不同成分,使神經網絡能夠在解決如服裝檢測、檢索、姿態估計等特定問題時應更多關注圖像中的哪些特征。

考慮到服裝的非剛性變形較大,不同模特姿態和服裝風格下服裝的關鍵點存在較大的空間差異,本文提出一種基于混合域注意力(mixted domain attention, MDA)機制的服裝關鍵點定位及屬性預測算法,利用循環十字交叉注意力(recurrent criss-cross attention,RCCA)[16]模塊獲取服裝關鍵點之間潛在的空間關系,通過高效通道注意力(effective channel attention,ECA)[17]模塊獲得通道之間的交互信息,以期優化算法模型的性能,提高服裝關鍵點定位、服裝分類以及屬性預測效果。

1 服裝關鍵點定位及屬性預測算法

基于混合域注意力機制的服裝關鍵點定位及屬性預測算法(MDA-DFA)是在Deep Fashion Analysis(DFA)[7]算法的基礎上,引入RCCA算法和ECA算法來融合空間域和通道域注意力機制,以便更好地提取服裝特征,最終提高服裝關鍵點定位和屬性預測效果。

1.1 DFA算法

DFA算法[7]主要是基于VGG-16網絡,如圖1所示。該算法將原始圖像的大小調整為224像素×224像素,采取與VGG-16網絡相同的初始卷積操作,在Conv4_3層后利用連續的卷積和轉置卷積操作生成關鍵點熱圖進行定位。關鍵點熱圖特征和Conv4_3卷積特征共同組合成新的注意力映射,使得DFA網絡可根據局部關鍵點和全局特征更靈活地聚焦服裝的重要功能部分。

圖1 DFA網絡架構Fig.1 Network architecture of the DFA

DFA算法利用轉置卷積對特征圖進行上采樣,獲得的關鍵點熱圖在具有高分辨率的同時未丟失信息,與輸入的服裝圖像具有相同的尺寸,可提高服裝關鍵點定位的準確性。其以關鍵點熱圖為基礎產生統一的空間注意力機制,使網絡具有足夠的信息去增強或減弱特征,避免了特征選擇中的硬確定性約束,可取得較好的分類和屬性預測效果。

1.2 全局信息模塊

因服裝關鍵點之間存在潛在的空間聯系,為獲得服裝圖像的全局特征,利用非局部空間連接算法中的RCCA算法獲取特征的全局聯系,從而捕獲關鍵點之間的空間關系。RCCA算法是將Criss-Cross Attention(CCA)重復操作R次,通過計算任意兩個位置間的交互直接捕捉遠程的上下文信息,而不局限于相鄰的點,相當于構造了1個和圖像尺寸相同的卷積核,因此可以獲得全局信息。

CCA[16]的運算過程如圖2所示。給定局部特征映射F∈C×W×H,對F分別應用1個帶有1×1濾波器的卷積層后,得到兩個特征映射Q和K,其中{Q,K}∈C′×W×H,C′為降維后的通道數。

圖2 CCA算法的細節Fig.2 The details of the CCA algorithm

得到Q和K后,通過Affinity運算和歸一化處理進一步生成注意力映射圖A∈(H+W-1)×(W×H)。在Q的空間維度的每個位置u都可以得到一個向量Qu∈C′。通過從與位置u在同一行或同一列的K中提取特征向量獲得集合Ωu∈(H+W-1)×C′。Ωi,u∈是Ωu的第i個元素。Affinity運算的定義如式(1)所示。

(1)

式中:di,u為特征Qu和Ωi,u的關聯度,di,u∈D,i的取值范圍為1到H+W-1的整數,D∈(H+W-1)×(W×H)。在D的通道維度上應用Softmax層計算注意力映射圖A。

在F上應用另一個帶有1×1的濾波器的卷積層生成V∈C×W×H用于特征自適應[16]。在V的空間維度的每個位置u,都能得到1個向量Vu∈C和1個集合Φu∈(H+W-1)×C。集合Φu是V中與位置u同行或同列的特征向量的集合。上下文信息由式(2)定義的Aggregation運算收集。

(2)

RCCA算法模塊首先將局部特征映射F輸入到CCA模塊中,聚集十字交叉路徑中的每個像素的上下文信息生成1個新的特征映射F′,則特征映射F′同時包含水平和垂直方向上的上下文信息。為獲得更豐富、密集的上下文信息,將特征映射F′再次輸入到CCA模塊中,并輸出特征映射F″。特征映射F″中的每個位置實際上收集了服裝圖像上所有像素的信息,捕獲了長依賴關系。前后兩個CCA模塊可共享相同的參數,避免增加額外的成本。

1.3 通道注意力模塊

DFA算法采取均分權重的方法處理通道域中的圖像特征信息,然而,各通道域中的圖像特征信息對分類和屬性預測的影響是各不相同的。因此,在提出的MDA-DFA算法中引入通道域注意力機制。通道域注意力機制的原理為通過建立不同通道之間的相關性,基于網絡學習的方式自動獲取每個特征通道的重要程度,據此賦予每個通道不同的權重系數,從而強化重要的特征并抑制不重要的特征。具體操作:在RCCA算法和關鍵點注意操作后,網絡的通道注意模塊根據通道對服裝分類和屬性預測任務貢獻的程度為512個通道分配權重,然后將其與原始特征映射相乘,得到加權的服裝特征映射。

現有的通道注意力方法多致力于開發復雜的網絡模塊以實現更好的性能,因此不可避免地增加了模型的復雜性。而ECA網絡在全局平均池化操作后,利用大小為k的一維卷積考量各通道及其k個鄰居,從而在不降維的情況下捕獲局部跨通道交互信息,克服了網絡模型的性能與復雜性之間的矛盾。給定不降維的聚合特征y∈C,通過式(3)學習通道注意力。

ω=σ(Wy)

(3)

式中:W為一個C×C的參數矩陣,使用帶狀矩陣即式(4)學習通道注意力。

(4)

對式(4)來說,計算yi的權重時只需考慮yi和它的k個鄰居之間的相互作用,如式(5)所示。

(5)

式中:Ωi,k為yi的k個相鄰的通道的集合。

為了進一步提高性能,利用式(6)所示的方法使所有通道共享相同的學習參數。

(6)

這種讓所有通道共享相同學習參數的方法可以通過式(7)實現。

ω=σ(C1Dk(y))

(7)

式中:C1Dk表示卷積核大小為k的一維卷積。使用這種跨通道交互的方法只涉及k個參數,在模型復雜度較低的情況下保證了ECA模塊的效率和性能。

給定通道數C,根據式(8)確定內核大小k。

(8)

式中:|t|odd表示最接近t的奇數。將通道數代入式(8)計算得到k=5。

鑒于空間域注意力機制和通道域注意力機制在圖像特征提取方面的優勢,提出混合域注意力機制模型以充分利用兩個注意力的信息,從而獲得更好的服裝關鍵點定位和屬性預測效果。

1.4 算法結構

MDA-DFA算法的整體架構主要分為5個階段,如圖3所示。

圖3 MDA-DFA算法的網絡架構Fig.3 The network architecture of MDA-DFA algorithm

階段1:利用VGG-16的前4層網絡提取原始服裝圖像的特征映射。

階段2:將階段1的輸出,通過RCCA模塊建立特征的全局聯系,預測服裝的關鍵點位置。

階段3:基于階段2的特征生成服裝關鍵點熱圖,將階段1得到的特征映射與熱圖進行通道拼接,再輸入到空間注意力網絡,得到加強服裝關鍵點信息后的特征映射。

階段4:將階段3得到的特征映射與初始特征進行融合,再通過ECA網絡建模卷積特征各通道之間的作用關系,從而改善網絡模型的表達能力,更好地獲取服裝特征。

階段5:將階段4獲得的特征送入VGG-16第5層及之后的網絡,再分別對服裝圖像進行分類和屬性預測,得到相應的結果。

2 試驗設計與結果分析

2.1 試驗平臺和數據集

試驗采用的平臺配置為Inter i7 CPU,GTX1080GPU,16 GB內存;軟件為Ubuntu操作系統,Python 3.6語言在Pytorch框架下實現。

試驗采用的數據集為當下權威的服裝評測數據集之一,香港中文大學多媒體實驗室開源的大型服裝數據集Deep Fashion[6]。該數據集含有非常豐富的標注信息,包括服裝主體bounding box、服裝類別、1 000種屬性(細節特征)、8個服裝關鍵點;數據中有正常、中等、嚴重等不同程度的變形圖片;服裝圖片的視角按照人體穿著分為上半身、下半身、全身,其服裝關鍵點個數分別為6、4、8。

2.2 試驗設計及步驟

給定一個服裝圖像I,目標是預測服裝關鍵點的位置L(見式(9))、服裝類別B以及服裝屬性向量A。在DeepFashion數據集中,所有服裝被分為50類,類別標簽滿足0≤B≤49。服裝分類預測可視為一個1-of-k(啞編碼)的分類問題,如“14”表示斗篷(Poncho),屬于上身衣服,“25”表示牛仔褲(Jeans),屬于下身衣服,“40”表示連衣裙(Dress),屬于全身衣服;屬性預測為多標簽分類問題,標簽向量A=(a1,a2,…,an),其中n為屬性總數,ai∈{0,1},ai=1表示服裝圖像具有第i個屬性,反之則不具有。

L={(x1,y1),(x2,y2),…,(xnl,ynl)}

(9)

式中:xi和yi是每個關鍵點的坐標,nl為關鍵點的總數。試驗采用的數據集標注是8個關鍵點,故nl=8。

設計5組對比試驗以改進服裝關鍵點定位、分類與屬性預測效果。

試驗1:基于DFA網絡架構,在第1次卷積和轉置卷積階段后加入RCCA網絡模塊,以克服多次卷積對圖像像素之間上下文信息提取不足的局限性。

試驗2:優化網絡結構,將RCCA網絡模塊移至所有的卷積和轉置卷積操作之前,以保留更多的原始全局信息。

試驗3:基于DFA算法,利用試驗2的設計獲取圖像像素之間的上下文信息,將關鍵點之間的內在聯系用于服裝分類與屬性預測。

試驗4:在試驗3的基礎上,在空間注意力網絡之后加入ECA網絡模塊,讓網絡學習通道之間的交互信息。

試驗5:MDA-DFA算法。將試驗3和4的設計結合起來,在網絡架構引入RCCA和ECA模塊,融合空間域和通道域注意力機制,以更好地提取服裝特征。

試驗步驟:

(1)讀取數據集及初始設置。讀取數據集的路徑、圖像標注等相關信息;設置批次、批大小、學習率等初始參數。初始學習率設置為0.000 1,并以0.9的線性衰減率衰減。整個模型訓練10個回合,訓練的批處理大小為16。

(2)創建info.csv文件。將數據集中所有的標注信息整合到文件info.csv中,以滿足機器學習的要求。其中每一行代表一張圖片,包含圖片ID、類別、關鍵點位置、屬性、bounding box等信息。

(3)數據集預處理。定義服裝圖片的增廣函數,進行翻轉、隨機裁剪、中心裁剪、隨機翻轉等多種預處理操作,增強數據的穩健性,處理后圖片的尺寸為224像素×224像素。

(4)搭建基礎網絡架構。首先定義整個VGG-16的網絡架構,然后定義高斯核函數和損失函數。

(5)搭建關鍵點定位的網絡。首先定義關鍵點上采樣函數,然后定義關鍵點提取函數和訓練網絡。

(6)搭建RCCA網絡。定義實現RCCA的函數。

(7)搭建ECA網絡。定義實現ECA的函數。

(8)提取初始特征。通過在ImageNet數據集上加載了VGG-16的預訓練模型對本文的模型參數進行初始化。

(9)訓練數據。設置相關的參數進行訓練,每10個step顯示一次計算出的損失值。

(10)預測關鍵點。訓練結束后,根據訓練好的模型對測試集進行預測。將關鍵點的熱圖記為M′∈R224×224×8,添加高斯濾波器,對關鍵點熱圖進行可視化處理。關鍵點定位采用均方誤差(MSE)的損失函數,如式(10)所示。

(10)

式中:N為數組元素的總數,i,j∈(0,224)。

(11)預測類別與屬性。訓練結束后,根據訓練好的模型對測試集進行預測。分別使用兩個全連接層預測服裝圖像的類別和屬性,它們的損失函數都是標準的交叉熵損失,如式(11)所示。

(11)

式中:X[true]為樣本真實標簽的得分;X[j]為第j個類別的得分。

(12)觀察模型損失。利用TensorBoard可視化工具實時觀察和記錄損失值和預測的結果。

2.3 試驗結果與分析

采用常用指標歸一化誤差(Enormalized)衡量圖像關鍵點定位算法的性能,計算方法如式(12)所示。

(12)

采用準確率(Raccuracy)和召回率(Rrecall)兩種標準試驗評測指標客觀分析算法模型在服裝數據集上的表現。準確率是指預測正確的結果占總樣本的百分比,如式(13)所示。召回率又叫查全率,是指在實際為正的樣本中被預測為正樣本的概率,如式(14)所示。

(13)

(14)

式中:Ntp為模型預測正確的正樣本數量;Ntn為模型預測正確的負樣本數量;Nfp為模型預測錯誤的正樣本數量;Nfn為模型預測錯誤的負樣本數量。

將試驗1和試驗2的關鍵點定位歸一化誤差與在相同數據集下的其他方法(如Fashion Net[6]、Deep Alignment[2]、DLAN[3]、DFA[7]和文獻[15])的試驗結果進行對比,如表1所示。

表1 不同算法的關鍵點定位歸一化誤差Table 1 Normalized location error of key point of different algorithms %

由表1可知,試驗1除了左袖口和下擺處的關鍵點定位誤差略大于DFA算法,下擺處的定位誤差略大于文獻[15]以外,其他關鍵點定位誤差均小于所有對比方法的試驗數據,體現了本文算法的有效性及其優勢。將改進后試驗2的結果與DFA算法進行對比可知,試驗2中算法的定位誤差在右領口處減少0.07個百分點,右袖口處減少0.16個百分點,左腰線處減少0.05個百分點,右腰線處減少0.04個百分點,左下擺處減少0.02個百分點。說明試驗2體現出更佳的性能。

試驗5的服裝屬性預測、服裝分類以及關鍵點定位的損失曲線如圖4所示。

圖4 服裝屬性預測、分類和關鍵點定位的損失曲線Fig.4 Loss curves for apparel attribute prediction, classification and key point location

由圖4可知,在訓練之初,模型的效果并不好,損失值較大。隨著迭代次數的增加,網絡所得誤差值通過反向傳播求解梯度,并通過梯度下降的方式更新模型參數,訓練的誤差值才逐漸降低。當誤差值降到一定閾值時,模型收斂,則訓練停止。

圖5為服裝的關鍵點定位可視化結果。

圖5 原圖和相應的關鍵點熱圖Fig.5 Original drawing and corresponding key point heat map

將試驗3~5的試驗結果與在相同數據集下其他方法,如WTBI[18]、DARN[19]、Fashion Net[6]、Weakly[20]、文獻[11]、DFA[7]和文獻[15]的試驗數據進行對比,如表2所示。

由表2可知,MDA-DFA算法優于試驗3和4的試驗結果,在top-3的分類結果中,同時融合空間聯系RCCA和高效通道注意力ECA的MDA-DFA網絡得到的準確率最高,為91.36%,相比改進前的DFA網絡提高了0.2個百分點。在屬性預測的結果中,MDA-DFA算法的召回率也更高,總體上表現更佳,其中top-5面料預測的召回率比DFA網絡提高了0.59個百分點。由此可見,將RCCA和ECA結合起來使用時網絡的性能更優。

因此,提出的基于混合域注意力機制的服裝關鍵點定位與屬性預測算法能有效提高對服裝袖口和腰線處關鍵點定位的精度,對困難關鍵點的定位有比較明顯的改進作用,并在一定程度上提高了服裝的分類與屬性預測效果。

表2 不同算法的服裝分類準確率與屬性預測的召回率Table 2 Clothing classification accuracy and attribute prediction recall rate of different algorithms %

3 結 語

提出一個基于混合域注意力機制的服裝關鍵點定位及屬性預測的算法,利用RCCA模塊獲取服裝圖像像素的上下文信息,從而捕獲關鍵點之間的空間聯系,利用局部跨通道交互策略生成通道注意力捕獲卷積通道間的交互信息,并將兩種注意力分支網絡得到的特征融合后再進行分類和屬性預測。結果表明該算法取得了不錯的效果。但相比人類對于服裝的理解,人工智能還差得很遠。在今后的研究中,可嘗試將神經進化算法等生物學策略應用到相關領域,以促進其在計算機視覺中的應用。

猜你喜歡
關鍵點卷積注意力
基于全卷積神經網絡的豬背膘厚快速準確測定
論建筑工程管理關鍵點
基于圖像處理與卷積神經網絡的零件識別
讓注意力“飛”回來
水利水電工程施工質量控制的關鍵點
一種基于卷積神經網絡的地磁基準圖構建方法
基于3D-Winograd的快速卷積算法設計及FPGA實現
利用定義法破解關鍵點
A Beautiful Way Of Looking At Things
機械能守恒定律應用的關鍵點
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合