?

基于門控時空注意力的視頻幀預測模型

2024-01-22 10:55李衛軍張新勇高庾瀟顧建來劉錦彤
鄭州大學學報(工學版) 2024年1期
關鍵詞:時空注意力架構

李衛軍, 張新勇, 高庾瀟, 顧建來, 劉錦彤

(1.北方民族大學 計算機科學與工程學院,寧夏 銀川 750021;2.北方民族大學 圖像圖形智能處理國家民委重點實驗室,寧夏 銀川 750021)

近年來,隨著科技的飛速發展,智能設備得到了廣泛的普及,由此產生了海量的無標簽視頻數據。智能預測與決策系統在生活中具有重要的地位,視頻幀預測作為智能預測的關鍵技術,能夠為決策系統提供支持,在氣象預警[1]、交通流量[2]等領域具有廣泛的應用前景。

目前,視頻幀預測模型的多幀預測能力不足,其復雜的時空結構導致視頻幀預測仍然是一項非常具有挑戰性的任務?,F有的視頻幀預測方法可以分為兩類,主要包括單進單出預測架構和多進多出預測架構。其中,單進單出預測架構是視頻幀預測的主流結構。Srivastava等[3]通過編碼器將視頻序列重建為固定長度的特征向量,并輸入到長短期記憶網絡(long short term memory,LSTM)中進行多幀預測。為提高LSTM的特征捕捉能力,Shi等[4]采用卷積結構對LSTM的狀態轉移函數進行了擴展。為增強不同層次循環網絡間的聯系,Wang等[5]通過在自底向上和自頂向下的方向上建立記憶流,使模型能夠同時對短期變化和長期動態趨勢進行建模。在此基礎上,Wang等[6]建立了一種基于因果LSTM的循環網絡,由級聯的雙存儲器和梯度高速單元組成,能夠自適應地捕獲短期和長期依賴關系。上述方法能夠有效增強模型的特征學習能力,但隨著預測長度的增加會存在誤差累積的問題,導致預測精度迅速下降。

隨著神經網絡結構的快速發展,多進多出預測架構能夠有效避免在長期預測中受到的誤差累積影響。Liu等[7]采用3D卷積自編碼器學習體素流,并通過現有的流動像素值來合成未來視頻幀。Aigner等[8]提出一種基于時空三維卷積的生成式對抗網絡(genertive adversarial network,GAN),該架構能夠一次預測多個未來幀。Ye等[9]分別對空間特征和時間特征進行建模,并采用對抗損失函數來提高預測清晰度。對抗網絡和3D卷積的引入雖然能夠有效提高預測性能,但也導致模型變得更加復雜。

為了平衡模型的綜合性能,Gao等[10]提出了一種簡單的視頻預測模型(simple video prediction,SimVP),通過采用簡單的組成結構和訓練策略,以有效減少模型的參數量和訓練時間。但SimVP仍然存在兩個問題:①時空特征學習能力仍然不足;②難以平衡空間特征及時間特征的捕捉能力,導致對時間維度的信息學習不充分。受圖像分割[11]領域最新進展的啟發,本文提出了門控時空注意力。其中,空間注意力關注幀內空間位置下的相互關系,時間(通道)注意力[12]則關注幀間的變化趨勢,并采用門控機制來融合獲得的時間特征和空間信息。

1 相關工作

1.1 基于循環神經網絡的單進單出預測架構

目前,基于循環神經網絡的單進單出預測架構被廣泛用于處理序列數據。Wang等[13]利用相鄰隱藏狀態之間的差異信息對時空動力學中的非平穩和近似平穩特性進行建模。從預測編碼的角度,Lotter等[14]將真實信號和預測信號之間的差異信息作為網絡參數的更新指標。此外,受偏微分方程(PDEs)的啟發,Guen等[15]提出了物理動力學網絡(physical dynamics network,PhyDNet),采用雙分支架構來分離視頻中的物理動力學和未知因素。然而,該模型難以平衡長期和短期的預測性能。因此,Pan等[16]提出了基于特征分離原理的泰勒網絡(Taylor network,TaylorNet),該架構采用泰勒級數對視頻序列進行建模,有效提高了模型的多幀預測能力。上述方法通常采用堆疊各種特征學習模塊來提高預測效果,導致模型的計算量和參數量過大,這限制了模型的進一步廣泛應用。

1.2 基于卷積神經網絡的多進多出預測架構

近年來,基于卷積神經網絡的多進多出預測架構開始被應用在視頻幀預測領域中。Sun等[17]提出了一種新的U-net預測架構,能夠對神經網絡不同層次中的多個時間和空間尺度進行統一建模。受Transformer在計算機視覺領域成功應用的啟發,Ning等[18]提出了一種基于局部時空塊擴展的Transformer預測架構,通過將二維卷積融合到多頭注意力中以捕捉序列中的長期依賴關系。此外,Tan等[19]提出了一種輕量型時空預測學習框架,采用膨脹卷積構建時空注意力來增強模型的特征捕捉能力。多進多出預測架構通常構建各種模塊來增強空間特征的獲取能力,但對時間特征的學習仍然不足。

本文受SimVP框架的啟發,構建了基于門控時空注意力的視頻幀預測模型。通過多尺度深度條形卷積和通道注意力來捕捉復雜的時空運動趨勢,同時采用門控機制來平衡模型的時空特征學習能力,有效地增強了模型的時空動力學建模能力。

2 本文算法

2.1 問題描述

定義一個X={xt+1,xt+2, …,xt+m}表示長度為m的輸入視頻幀序列,Y={yt+1,yt+2,…,yt+n}表示待預測的未來n幀真實序列,Y′={y′t+1,y′t+2,…,y′t+n}表示模型預測的未來n幀視頻序列,其中xt,yt和y′t分別表示第t時刻的原始幀、真實幀和預測幀。模型訓練的目的就是通過輸入的視頻序列X來預測未來的視頻序列Y′,同時對模型的可學習參數Θ進行優化,使真實序列Y和預測序列Y′之間的差異最?。?/p>

Θ*=argminL(FΘ(X),Y)。

(1)

式中:Θ*為模型的最佳參數;FΘ為神經網絡模型;L為評估差異的MSE損失函數。

2.2 網絡結構

目前,在未來幀預測任務中領先的方法是SimVP架構,本文方法采用了類似的設計思想。如圖1所示,模型主要由空間編碼器、時空預測模塊和空間解碼器組成??臻g編碼器通過多層2D卷積來實現特征提取和下采樣操作,該模塊能夠將輸入的幀序列編碼到低維潛在空間。時空預測模塊主要由多個堆疊的門控時空網絡(MST)構成,MST通過對輸入的低維特征信息進行時空動力學建模,以學習視頻序列中的時間趨勢和空間相關性。此外,MST之間共享參數,這有效地減少了模型的參數量??臻g解碼器由2D卷積和上采樣操作組成,通過將時空預測模塊的輸出作為解碼器的輸入,以實現低維信息向真實預測幀的轉換,并且得到的預測序列可繼續作為模型的輸入進行后續的長期預測。

2.3 空間編碼器

如圖1所示,綜合考慮模型的計算量和參數量,空間編碼器采用了多層純卷積結構,主要由Conv2d、GroupNorm、SiLU組成。由于需要充分捕捉視頻幀的空間特征,并避免在下采樣過程中造成過多的信息損耗,本文在編碼器和解碼器之間建立了殘差連接,最大限度保留視頻幀的背景語義Bbn??臻g編碼器提取視頻序列高級特征信息的過程可以表示為

Zen,Bbn=σ(Norm2d(Conv2d(Xn)))。

(2)

式中:σ為激活函數SiLU;Norm2d為組歸一化層;Xn為輸入序列;Conv2d為2D卷積運算符;Zen為獲取的低維信息。通過將2D卷積的步長(step)設置為2實現下采樣,而設置為1則進行卷積操作。

2.4 時空預測模塊

時空預測模塊位于整個模型的中間部分。同空間編碼器和空間解碼器對單幀圖像進行操作不同,預測模塊處理沿時間維度堆疊形成的視頻幀序列。由于視頻幀預測是一種像素密集型任務,預測輸出和輸入的視頻幀分辨率相同,因此,預測模塊即要高效提取時空特征,又要盡可能避免預測過程中增大感受野導致的細節缺失。因此,本文提出了一種新的門控時空網絡(MST),如圖2所示。MST是一種基于Transformer的變體,由歸一化層(Batch Norm)、門控時空注意力層和全連接層組成。其中,門控時空注意力層主要包括空間注意力、時間注意力和門控融合機制3個部分,空間注意力能夠學習幀內的多尺度特征信息,而時間注意力能夠捕捉幀間的時間變化趨勢。此外,門控融合機制能夠有效地融合空間信息和時間特征,使模型能夠采取相同的重視程度來學習序列中的空間相關性和時間趨勢。門控時空注意力對視頻序列中每個時空位置下的運動強度進行合理的權重分配,這有效平衡了時間特征及空間信息的捕捉能力,同時能夠有效提高模型的時空預測建模能力。

圖2 MST網絡結構圖Figure 2 MST network structure diagram

2.4.1 時空注意力

為了有效捕捉空間相關性和時間依賴關系,注意力機制需要分解為空間注意力和時間注意力,以充分學習幀內和幀間的相互作用。由于傳統空間注意力的特征捕捉能力不足,并忽略了多尺度感受域的重要性,因此,本文采用多尺度深度條形卷積來構建空間注意力,同時使用大卷積核來增強模型的特征捕捉能力。如圖3所示,空間注意力獲取特征信息的過程主要包括2個階段:首先建立基于大卷積核的多尺度深度條形卷積Cdw1×k和Cdwk×1,以提取視頻序列Zi中的多尺度特征信息;然后通過大小為1×1的卷積核Conv2d1×1來聚合捕捉到的多尺度信息Zm??臻g注意力捕捉多尺度特征信息的過程可以表示為

Zm=∑k∈{7,11,21}Cdwk×1(Cdw1×k(Zi));

(3)

Zh=Conv2d1×1(Zm)。

(4)

式中:k為卷積核大小,k∈{7,11,21}代表k分別取7、11和21;Zh為聚合后的多尺度信息。

圖3 門控時空注意力網路結構Figure 3 Structure of the gated spatio-temporal attention network

空間注意力能夠有效捕捉幀內的空間相關性,但難以完整學習幀間的時間變化趨勢。因此,本文采用通道注意力作為時間注意力,利用通道間的相互關系獲取時間權重Sa。該過程可以表示為

Sa=FC(Avgpool(Zi))。

(5)

式中:Zi為原始輸入信息;Avgpool為全局平均池化;FC為全連接層。

2.4.2 門控融合機制

為了使模型對空間特征和時間特征采取相同的重視程度,本文提出了門控融合機制對空間注意力和時間注意力進行深度融合。如圖3所示,門控融合過程可以分為3個階段:首先,通過拆分操作split將通道數為2C的多尺度空間信息Zh拆分為通道數為C的空間特征Gs和Zt;其次,將空間信息Zt同時間權重Sa相乘,并通過激活函數Sigmoid將其映射至[0,1]以獲得時空權重;最后,將空間特征Gs乘以時空權重以獲得多尺度時空特征Z″i。整個注意力的融合過程可以表示為

Gs,Zt=split(Zh);

(6)

Z″i=σ(Sa?Zt)⊙(Gs)。

(7)

式中:σ為激活函數Sigmoid;⊙為哈達瑪積(Had-amard product);?為克羅內克積(Kronecker)。

2.5 空間解碼器

如圖1所示,空間解碼器由Conv2d、GroupNo-rm、PixelShuffle組成,通過將預測模塊輸出的預測信息輸入到空間解碼器中,能夠將低維預測信息Zc解碼為圖像序列Y′,同時補充背景語義Bbn??臻g解碼器輸出預測圖像序列的過程可以表示為

Y′=σ(Norm2d(Conv2d(Zc,Bbn)))。

(8)

式中:σ為激活函數SiLU;Conv2d為2D卷積,通過像素重組層(PixelShuffle)實現上采樣操作,否則進行步長為1的卷積操作。

3 實驗結果及分析

3.1 實驗環境及模型參數

本文采用的軟件運行平臺為Windows10專業版64位,深度學習環境軟件配置為Python3.8和PyTorch1.10。硬件配置為NVIDIA TITAN V顯卡,采用CUDA10.2,使用Adam優化器、OneCycle[20]及余弦退火學習率調整策略來訓練模型。

該模型的超參數主要包括學習率、訓練次數、drop_path、批處理大小、MST單元數等。其中,在Moving MNIST、TaxiBJ、WeatherBench和KITTI數據集上,學習率分別設置為0.001 0、0.000 5、0.005 0、0.005 0,訓練次數分別為600、50、50、100,而drop_path分別設置為0、0.2、0.2、0.2,批處理大小統一設置為16,MST單元數分別設置為8、8、8、6。

本文采用MSE損失函數來對模型進行訓練,并通過均方誤差(MSE)、平均絕對誤差(MAE)、結構相似指數(SSIM)和均方根誤差(RMSE)來評估預測圖像的質量。

3.2 實驗評估

本文在Moving MNIST[3]數據集上進行根據10個條件幀來預測10個未來幀的實驗,并同先進的循環式模型和多進多出預測方法對比來評估模型的時空預測學習能力。如表1所示,盡管沒有采用循環式設計,本文方法在Moving MNIST數據集上依然獲得了較高的預測精度,同SimVP相比,MSE和MAE分別降低了14.7%、8.9%,同時參數量和計算量也有所下降。雖然推理效率有所降低,但時空特征學習能力更強,這顯著地減少了模型的訓練次數,同時訓練時間縮短了近61 h。同最先進的循環式模型TaylorNet相比,本文模型雖然計算量有所增加,但MSE和MAE也分別降低了8.6%、3.7%,同時推理效率提高了12%,并顯著地縮短了訓練時間??梢钥闯?本文方法有效解決了循環式架構預測精度低、推理效率低和訓練時間長等問題。此外,同最先進的多進多出模型SimVP+gSTA相比,MSE和MAE也下降了9.0%、7.0%,在相同的訓練次數下,本文方法獲得了更高的預測精度和推理效率。

表1 在Moving MNIST數據集上的實驗結果Table 1 Experimental results on the Moving MNIST dataset

圖4所示為Moving MNIST數據集的預測結果,其中,誤差特征圖為真實幀和預測幀之間差值的絕對值??梢钥闯?隨著預測長度的增加,在t=10時,TaylorNet由于受到誤差累積的影響,產生了最密集的誤差圖。SimVP雖然解決了誤差累積的問題,但特征學習能力仍然不足,其誤差主要集中在圖像細節。而本文方法避免了誤差累積的影響,同時具有高效的特征學習能力,獲得了最佳的預測圖像。

圖4 Moving MNIST數據集預測結果Figure 4 Moving MNIST dataset prediction results

本文在TaxiBJ[22]數據集上同經典的基線模型和最新的先進方法對比來評估模型的交通流預測性能,如表2所示??梢钥闯?本文方法獲得了較高的預測精度,同最先進的循環式模型PredRNN相比,MSE和MAE分別降低了4.1%、2.6%,同時計算量減少了39.8 GFlops。因此基于端對端的多進多出預測架構顯著優于循環式單進單出預測架構,能夠有效增強模型的預測性能,并減少計算量。而同最先進的多進多出模型TAU相比,MSE也降低了1.3%,并且計算量僅略微增加。此外,SimVP是近期提出的一種簡單的多進多出純卷積網絡,該模型構造簡單,具有較高的綜合性能,本文方法同SimVP相比,在MSE和MAE上也分別降低了6.7%、3.2%,同時能夠顯著減少計算量。

表2 在TaxiBJ數據集上的實驗結果Table 2 Experimental results on the TaxiBJ dataset

圖5所示為TaxiBJ數據集的預測結果,可以看出,隨著預測長度的增加,在t=4時,循環式模型受到誤差累積的影響,導致MAU的預測效果迅速下降,SimVP雖獲得了不錯的預測效果,但對時間趨勢的捕捉能力仍然不足。本文方法能夠有效地平衡時間及空間特征的學習能力,取得了最佳的預測效果,具有很好的交通流預測性能。

圖5 TaxiBJ數據集預測結果Figure 5 TaxiBJ dataset prediction results

氣候預測是時空預測學習的另一項基本任務,本文在WeatherBench[24]數據集上同時空預測學習方法進行了對比試驗。如表3所示,循環式時空預測學習方法雖取得了一定效果,但復雜的結構也導致計算量過大,而本文方法采用多進多出預測架構實現了更好的綜合性能。其中,同最先進的循環式模型MAU相比,MSE降低了11%,并且計算量減小了32.6 GFlops。而同最先進的多進多出模型SimVP+gSTA相比,在MAE上也降低了0.9%。此外,同SimVP模型相比,MSE和MAE分別降低了10.5%、7.5%。

表3 在WeatherBench數據集上的實驗結果Table 3 Experimental results on the WeatherBench dataset

圖6 WeatherBench數據集預測結果Figure 6 WeatherBench dataset prediction results

圖6所示為WeatherBench數據集預測結果??梢钥闯?隨著預測長度的增加,在t=12時,SimVP模型難以完整地預測圖像細節,MAU由于預測機制的原因,在長期預測中精度會迅速下降。而本文方法獲得了最稀疏的誤差圖,高效的特征提取能力能夠學習到更多的圖像細節,并且不受誤差累積的影響,在全球氣候預測任務中表現出極佳的性能。

復雜的真實世界往往包含了不同運動對象的各種非線性時空運動,這導致時空預測學習更加具有挑戰性。為了評估模型的泛化能力和適應性,本文在KITTI[14]數據集上進行訓練,并在CalTech Pedestrian數據集[14]上進行最終測試。其中,模型在KITTI和Caltech Pedestrian上采用了相同的參數設置,統一進行通過10個條件幀來預測1個未來幀的對比實驗。

如表4所示,本文方法在真實數據集KITTI上獲得了較高的預測精度,同基線模型SimVP相比,MSE和MAE分別降低了18.5%、12.3%。而同最先進的循環式模型ConvLSTM相比,本文方法在MSE和MAE上也分別降低了6.4%、6.4%,同時計算量更小。此外,同最先進的多進多出模型SimVP+gSTA相比,雖然MSE略微有所上升,但MAE降低了1.7%,并且計算量減少了45.6 GFlops??梢钥闯?多進多出預測架構在預測精度上顯著優于循環式預測架構,而本文方法通過較少的計算量達到了和SimVP+gSTA模型同樣先進的預測性能,并且顯著優于其他時空預測學習方法,具有很好的自動駕駛預測能力。

表4 在KITTI數據集上的實驗結果Table 4 Experimental results on the KITTI dataset

3.3 消融擴展實驗

為分析門控時空注意力每個局部模塊對最終預測性能的影響,本文在TaxiBJ數據集上進行了消融實驗。表5所示為消融實驗結果,其中“No/MST”表示用1×1卷積替換門控時空注意力層,“No/Sat-3×3”和“No/Sat-7×7”分別是將空間注意力的多尺度深度卷積替換成3×3卷積和7×7卷積,“No/Tat”表示沒有設置時間注意力,“No/Mk”表示不采用門控融合機制平衡注意力。而“MST-4”、“MST-6”和“MST-10”則表示MST的數量分別設置為4、6和10。

如表5所示,采用門控時空注意力層使得MSE和MAE分別降低了11.4%和3.8%。同3×3卷積和7×7卷積相比,使用多尺度深度條形卷積能夠增強模型的感受野和捕捉多尺度特征的能力,使得MSE分別降低了3.7%、1.1%。通過時間注意力學習幀間的相互作用,使MSE也降低了1.8%。而門控機制深度融合了兩種注意力,MSE降低了1.6%??梢钥闯?模型中的每個模塊都能夠有效提高最終的預測精度。此外,設置過多的MST單元帶來的效果提升并不明顯,同時導致了模型的參數量和計算量增大。因此,本文將MST數量設置為8,并同上述3個模塊進行集成獲得了最佳的時空預測性能。

本文在TaxiBJ數據集上進行了卷積擴展實驗如表6所示。其中,Dw為本文采用的多尺度深度條形卷積,Dc代表使用多尺度膨脹卷積,Mm代表采用多尺度2D卷積,并在最終測試階段通過重參數融合法[25]壓縮模型,Mc為使用多尺度2D卷積,其中7×7卷積被3個3×3卷積所代替。同Dc和Mc相比,Dw在預測性能、參數量及推理效率方面具有顯著優勢,而Mm由于采用了重參數融合法,獲得了最佳的推理效率,但本文方法獲得了更高的預測精度,同時具有很好的推理效率。

表6 卷積擴展實驗對比結果Table 6 Convolution extension experiment comparison results

為了探究不同預測架構對收斂性能的影響,本文在Moving MNIST數據集上進行了擴展實驗。圖7所示為不同模型收斂速度的對比結果??梢钥闯?同單進單出預測架構PhyDNet相比,多進多出預測策略在收斂性能方面具有顯著優勢。其中,本文方法實現了比SimVP更快的收斂速度,獲得了較好的收斂效果。這表明,在每次訓練中,模型能夠捕捉到更多的時空動態趨勢,這將會有效縮短模型的整體訓練時間。

圖7 收斂性能實驗結果Figure 7 Convergence performance experimental results

4 應用前景展望

隨著計算機視覺和深度學習技術的不斷發展,視頻預測技術將會具有更加廣泛的應用前景。在交通領域中,視頻預測技術可用于交通流監測、交通事故預測和城市規劃,通過分析實時的視頻流,交通系統可以更好地調度交通信號、減少擁堵,有效提高交通系統的效率。在氣象領域中,視頻預測技術可用于監測自然災害,通過分析衛星和地面攝像頭的視頻數據,能夠提前發現災害跡象并發出預警提示,有效減少損失。視頻預測技術的發展將會產生很多新的應用領域,在醫療領域中,視頻預測技術將可以用于遠程患者的監測、手術中的實時病情分析,醫生可以利用視頻預測技術來提高手術的準確性和安全性。視頻預測技術將在多個領域引領創新和變革,將會有助于提高效率和安全性,并有潛力挖掘出更多的應用場景,為未來創造更多的可能性。

5 結論

本文提出了門控時空注意力來生成幀內和幀間相互關系的時空權重,以充分學習視頻序列中空間維度和時間維度下有意義的時空信息,并采用門控融合機制平衡空間及時間注意力的特征捕捉能力,在Moving MNIST、 TaxiBJ、WeatherBench、KITTI數據集上的實驗結果均優于對比算法。此外,現有方法并未充分考慮幀內的多尺度信息交互作用對預測精度的影響,在今后的工作中,將研究如何更加高效地捕捉幀內及幀間的信息交互關系,同時保持模型結構簡單、參數量低和推理效率高等優勢。

猜你喜歡
時空注意力架構
基于FPGA的RNN硬件加速架構
跨越時空的相遇
讓注意力“飛”回來
功能架構在電子電氣架構開發中的應用和實踐
鏡中的時空穿梭
玩一次時空大“穿越”
LSN DCI EVPN VxLAN組網架構研究及實現
“揚眼”APP:讓注意力“變現”
A Beautiful Way Of Looking At Things
時空之門
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合