?

基于條件生成對抗網絡的稀疏樣本回歸預測模型

2023-06-03 03:15薛嘉南孫學宏劉麗萍
關鍵詞:高斯標簽模態

薛嘉南 , 孫學宏 , 劉麗萍,3

(1.寧夏大學 物理與電子電氣工程學院,寧夏 銀川 750021; 2.寧夏大學 信息工程學院,寧夏 銀川 750021;3.寧夏沙漠信息智能感知重點實驗室,寧夏 銀川 750021)

回歸預測是一種根據已知因變量與一個或多個自變量之間關系進行回歸分析,構建數學模型探尋變量之間關系,進而預測可能存在變量的回歸分析過程.其中概率回歸,作為估計變量之間非線性關系的一種有效方法,可以利用變量之間的條件概率分布估計其可能存在的非線性關系.然而,這些回歸分析都需建立在具有足夠觀察數據的基礎上,進行統計學上的數據分析,顯然這并不符合大多數現實狀況.近年來,針對樣本數據過少的問題學者們都嘗試過不同的解決方法.譚少卿針對計算機視覺領域中少樣本目標識別問題,提出了具備快速學習能力的少樣本學習深度網絡,并通過引入衡量圖像像素級別差的L1損失函數加快網絡收斂速度,從而提升對少樣本的識別精度和識別率[1].賈宇峰在生成對抗網絡數據增強方法的基礎上,提出一種基于監督學習的條件自我注意生成對抗網絡數據增強方法,通過額外信息指導生成網絡構造數據,提升視覺樣本特征與語義表征之間的兼容度,進而能夠合成高質量的目標類別樣本特征,大幅降低度量學習的難度[2—3].作為概率回歸模型中的一種機器學習算法,高斯過程回歸(Gaussian Process Regression, GPR)是使用高斯過程先驗對數據進行回歸分析的非參數模型.它將點與點之間同質性的度量作為核函數,根據輸入的訓練數據預測未知點值的回歸模型[4—7].近年來,通過利用最大信息向量機(Maximum Informative Vector Machine,MIVM)等方法,GPR也可在多模態稀疏樣本下獲得較好的回歸預測結果,但是高斯過程的計算需求是以訓練集的立方形式進行增長的,這加大了GPR在多模態數據樣本的計算代價[8—9].

作為一種利用生成網絡與判別網絡進行對抗學習,互相權衡的深度學習網絡[10],生成對抗網絡(Generative Adversarial Network,GAN)避免了許多難以處理的近似概率計算問題,在圖像、分類、自然語言處理等領域有廣泛應用.基于GAN的各種衍生網絡中,條件生成對抗網絡(Conditional Generative Adversarial Network,CGAN)原本作為指定圖像處理生成對抗網絡,通過在生成網絡與對抗網絡的輸入端添加標簽作為條件概率進入深度學習網絡,從而指定網絡生成所需的標簽數據[11].鑒于CGAN網絡自身條件概率分布的特性,可以認為其生成網絡是近似回歸模型的隱層網絡,因而也可以認為是概率回歸模型[12].由于自身的神經網絡特性,CGAN可以有效地逼近相對復雜的擬合過程.同時,利用生成網絡與判別網絡的對抗特性可最小化兩者的輸出誤差,直接應用于回歸問題中[4].由此可見,CGAN在解決GPR多模態稀疏樣本問題的同時,簡化了擬合過程所需要的復雜計算過程,并可在此基礎上利用貝葉斯理論,將概率回歸中的邊緣化方法改為隨機梯度哈密頓量蒙特卡羅等其他方式[13],進一步優化CGAN網絡以達到理想的預測結果.

本文提出一種基于CGAN的稀疏樣本概率回歸預測模型.將GPR作為概率回歸模型的預測精度基準線,通過對比實驗研究影響CGAN模型的網絡因素,構建適用于稀疏樣本回歸預測的CGAN模型.

1 實驗環境設置

本文實驗采用4種非線性函數構成模擬數據,分別為指數型、異向型、蝴蝶型以及混合多模態型,其中混合多模態型由1組異向型函數和3組線性函數疊加而成.此外,本實驗中的誤差衡量標準選擇均方誤差(Mean Square Error,MSE)以及平均絕對誤差(Mean Absolute Error,MAE).MSE是真實值與預測值之差平方后的求和平均,可以反映算法的魯棒性.MAE是真實值與預測值之差的絕對值,可以更好地反映預測值誤差的實際情況.此外,為了模擬數據稀疏的特征狀況,通過隨機丟棄部分特征和減少訓練樣本數,將完整的非線性模擬數據變為相對稀疏的數據樣本.圖1為4種非線性原始數據與稀疏數據分布對比示意圖,左側是樣本量為1 000的原始數據分布圖,右側是經過稀疏處理后樣本量為200的數據分布圖.

圖1 4種非線性原始數據與稀疏數據對比分布示意圖

2 GPR回歸預測基準線

作為本文概率回歸模型的預測精度基準線,GPR是一種非參數的貝葉斯回歸方法.通過貝葉斯推斷尋找下列方程的回歸過程:

y=wx+ε.

(1)

其中貝葉斯推斷需滿足

(2)

式中p(w|y,x)為后驗分布.根據貝葉斯推斷可知先驗分布的似然估計是后驗概率的邊緣似然率.因此,針對可能的未知樣本x*的預測分布可以表示為

(3)

這里為了簡化計算過程,假設先驗和似然均為高斯過程.從而針對未知樣本x*的預測分布也為高斯過程,進而可以使用均值獲得預測點,并使用方差獲得不確定性量化.GPR是在貝葉斯推斷的基礎上,在函數空間上指定先驗,并使用訓練數據計算后驗,最后再利用(3)式計算未知樣本x*的后驗預測分布.因此在高斯過程回歸中,首先假定一個由均值m(x)、協方差函數k(x,x′)、標簽分布以及噪聲分布組成的高斯過程:

(4)

其中高斯過程就像一個無限維的多元高斯分布,其數據集標簽的任何集合均服從聯合高斯分布.繼而從高斯過程先驗出發,訓練樣本和測試樣本的集合同樣也服從聯合多元高斯分布:

(5)

式中K為協方差核矩陣,其元素對應于觀測評估的協方差函數.至此,高斯過程回歸預測邏輯已經完成,均值m(x)以及協方差函數k(x,x′)的選擇將決定GPR的先驗模型選擇.本實驗選取Matern32作為GPR的核函數. 作為徑向基核函數(Radial Basis Function, RBF)的泛化內核,Matern內核的本質是一種類似于RBF的恒定核函數,通過附加參數控制核函數的平滑程度,其協方差公式為

k(x,x′)=

(6)

式中:d(·,·)為歐氏距離;Kv(·)為修改后的貝塞爾函數;Γ(·)是伽馬函數;參數v為32.當v接近無窮大時,Matern32協方差公式與RBF一致.由此可見,Matern32核函數可有效控制學習函數的平滑度,進而使得底層函數具有更好的關聯屬性.至此,基于Matern32的GPR算法流程可表述如下.

算法1 基于Matern32核函數的GPR算法

高斯過程回歸,徑向基核函數算法流程:

(Ⅰ)條件:訓練樣本集和預測樣本集分別為

Dn={(xi,yi)|i=1,2,…,n},

(Ⅱ)目標:尋找函數y=f(x).

(Ⅲ)算法流程:

Step 1 指定先驗方程

p(f1:n)=GP(m(·),k(·,·)),

其中

Step 2 指定似然方程

Step 3 代入訓練數據

Dn={(xi,yi)|i=1,2,…,n}.

Step 4 優化超參數

Step 5 通過超參數θ*訓練先驗概率p(f1:n)和似然方程p(y1:n|f1:n).

Step 6 代入新的樣本

Step 7 計算后驗分布

p(fnew|Dn)~N(μnew|Dn),σ2(xnew|Dn).

根據算法1分別獲得4種非線性模擬數據基于Matern32核函數的GPR回歸預測結果損失函數值(表1).

對于蝴蝶型、異向型以及指數型模擬數據樣本,基于Matern32的GPR回歸預測精度隨著樣本數的減少并無顯著變化;對于混合多模態模擬數據樣本,GPR的回歸預測精度均較差.這是由于以Matern32為核函數的GPR并沒有學習到多模態稀疏樣本的本質特征,反而當樣本特征范圍下降時精度才有所提升.由此可見,多模態稀疏樣本對GPR的預測結果影響較大.

3 CGAN回歸預測比較

2014年Mirza等人[11]提出CGAN網絡框架,期望通過代入標簽數據以約束生成網絡隨意生成數據的行為,進而在圖像生成領域簡化獲得指定圖像的過程.額外標簽信息在CGAN網絡內作為約束信息控制生成器生成的圖像樣本,也作為額外信息提供判別網絡的標準,其本質為基于標簽信息的條件概率分布約束過程.這種約束生成標簽數據的生成對抗網絡框架如圖2所示.

圖2 條件生成對抗網絡生成標簽數據示意圖

通過加入條件概率分布代替原來單純的噪聲和樣本分布,可表示為

(7)

Εz~pz(z)[log(1-D(G(z|x)|x))]].

(8)

將pg作為噪聲和y經過生成網絡生成的隱性條件分布.(8)式可進一步改寫為

Εy~pg(y|x)[log(1-D(y|x))]].

(9)

根據(9)式可知,相較于GAN而言,CGAN的樣本同時約束生成網絡和判別網絡.因此,生成網絡可以通過學習pg(y|x)的條件概率分布,近似逼近pd(y|x)原始數據的條件概率分布.所以,就回歸問題而言,生成網絡將噪聲z和條件變量x作為輸入,判別網絡將樣本y和條件變量x作為輸入.由此,CGAN概率回歸算法流程可表述如下.

算法2 基于梯度下降的CGAN概率回歸算法

(Ⅰ)條件:訓練樣本集和測試樣本集分別為

Dn={(xi,yi)|i=1,2,…,n},

(Ⅲ) 算法流程(分為判別網絡與生成網絡):

(ⅰ)判別網絡

Step 1 根據批次m獲得先驗噪聲

pg(z)={z(1),z(2),…,z(m)}.

Step 2 根據批次m獲得原始數據

pdatay=(y1,y2,…,ym).

Step 3 根據批次m獲得原始數據

(ⅱ)生成網絡

Step 1 根據批次m獲得先驗噪聲

pg(z)={z(1),z(2),…,z(m)}.

Step 2 根據批次m獲得原始數據

pdatax=(x1,x2,…,xm).

根據算法2可以得到改進的條件生成對抗網絡模型示意圖(圖3).

圖3 改進的條件生成對抗網絡模型示意圖

從圖3可知,相較于條件生成對抗網絡,標簽信息和噪聲作為生成網絡的輸入信息,通過隱層后,在生成網絡生成數據樣本.在判別網絡中,生成數據樣本、原始樣本及標簽信息同時進入隱層,最后在判別器得出判別結果,進而優化生成網絡.由此可得改進的條件對抗生成網絡對4種非線性模擬數據的回歸預測結果(圖4).

圖4 基于改進的CGAN模型的非線性模擬數據的回歸預測結果示意圖

圖4中,左側為稀疏后4種非線性數據樣本的分布示意圖,右側為經過訓練后噪聲經過生成網絡生成的預測結果示意圖.由圖4可以看出,對于指數型、異向型以及蝴蝶型預測結果,改進的CGAN模型的預測結果與原始樣本數據的分布近似,均在各種類型數據的覆蓋范圍之內.但是對于混合多模態型模擬數據,改進的CGAN模型的預測結果相較于原始數據不存在斷層,預測精度相較于指數型、異向型和蝴蝶型并不精準,但是相較于GPR而言具有更好的精度.這是由于混合多模態樣本數據由多組非線性樣本重疊構成,與真實數據的樣本分布較為近似.在這種混合多模態樣本分布中,GPR的核函數選取影響其收斂的區間,進而影響預測精度.但是CGAN模型的深度網絡構造可以有效減少由核函數選取帶來的影響,通過深度學習網絡模擬構建函數過程,可以近似模擬較為復雜的樣本分布空間,進而具有較為準確的預測結果.

定量分析結果見表2.由表2可知CGAN模型的4種非線性回歸預測損失函數值均小于GPR模型.

表2 CGAN與GPR的4種非線性回歸預測模型損失函數值

4 結論

本文把基于Matern32核函數的GPR作為概率回歸模型在稀疏樣本回歸預測中的基準線,通過分析CGAN模型與概率回歸模型的構造,提出利用CGAN解決稀疏樣本回歸預測問題.通過對比實驗發現,本文提出的CGAN相較于GPR可以有效提高稀疏樣本在回歸預測中的預測精度.

猜你喜歡
高斯標簽模態
數學王子高斯
天才數學家——高斯
無懼標簽 Alfa Romeo Giulia 200HP
不害怕撕掉標簽的人,都活出了真正的漂亮
標簽化傷害了誰
國內多模態教學研究回顧與展望
基于多進制查詢樹的多標簽識別方法
基于HHT和Prony算法的電力系統低頻振蕩模態識別
有限域上高斯正規基的一個注記
由單個模態構造對稱簡支梁的抗彎剛度
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合