Gumbel-sofmax采样技巧

以强化学习为例,假设网络输出的三维向量代表三个动作(前进、停留、后退)在下一步的收益,value=[-10,10,15],那么下一步我们就会选择收益最大的动作(后退)继续执行,于是输出动作[0,0,1]。选择值最大的作为输出动作,这样做本身没问题,但是在网络中这种取法有个问题是不能计算梯度,也就不能更新网络。

softmax采样

这时通常的做法是加上softmax函数,把向量归一化,这样既能计算梯度,同时值的大小还能表示概率的含义(多项分布)。

$\pi_k=\frac{e^{x_k}}{\sum_{i=1}^Ke^{x_i}}$

于是value=[-10,10,15]通过softmax函数后有σ(value)=[0,0.007,0.993],这样做不会改变动作或者说类别的选取,同时softmax倾向于让最大值的概率显著大于其他值,比如这里15和10经过softmax放缩之后变成了0.993和0.007,这有利于把网络训成一个one-hot输出的形式,这种方式在分类问题中是常用方法。

但这样就不会体现概率的含义了,因为σ(value)=[0,0.007,0.993]与σ(value)=[0.3,0.2,0.5]在类别选取的结果看来没有任何差别,都是选择第三个类别,但是从概率意义上讲差别是巨大的。

很直接的方法是依概率采样完事了,比如直接用np.random.choice函数依照概率生成样本值,这样概率就有意义了。所以,经典的采样方法就是用softmax函数加上轮盘赌方法(np.random.choice)。但这样还是会有个问题,这种方式怎么计算梯度?不能计算梯度怎么更新网络?

1
2
3
4
5
6
7
8
9
10

def sample_with_softmax(logits, size):

# logits为输入数据

# size为采样数

pro = softmax(logits)

return np.random.choice(len(logits), size, p=pro)

基于gumbel-max的采样

对于K维概率向量 $\alpha$,对 $\alpha$对应的离散变量 $x_i=log(\alpha_i)$ 添加Gumbel噪声,再取样:

$x=argmax_{i}(log(\alpha_i)+G_i)$

其中,$G_i$ 是独立同分布的标准Gumbel分布的随机变量, 标准Gumbel分布的CDF为 $F(x)=e^{-e^{-x}}$, 所以 $G_i$ 可以通过Gumbel分布求逆从均匀分布生成,即 $G_i=-log(-log(U_i)), U_i - U(0,1)$, 这样就得到了基于gumbel-max的采样过程:

  • 对于网络输出的一个K维向量v,生成K个服从均匀分布U(0,1)的独立样本 $\gamma_1,...,\gamma_K$;
  • 通过 $G_i$ 计算得到 $G_i$;
  • 对应相加得到新的值向量
  • 取最大值作为最终的类别
    1
    2
    3
    4
    5
    6

    def sample_with_gumbel_noise(logits, size):

    noise = sample_gumbel((size, len(logits))) # 产生gumbel noise

    return np.argmax(logits + noise, axis=1)

基于gumbel-softmax的采样

如果仅仅是提供一种常规 softmax 采样的替代方案, gumbel 分布似乎应用价值并不大。幸运的是,我们可以利用 gumbel 实现多项分布采样的 reparameterization(再参数化)。

在VAE中,假设隐变量(latent variables)服从标准正态分布。而现在,利用 gumbel-softmax 技巧,我们可以将隐变量建模为服从离散的多项分布。在前面的两种方法中,random.choice和argmax注定了这两种方法不可导,但我们可以将后一种方法中的argmax soft化,变为softmax。

$x=softmax((log(\alpha_i)+G_i)/temperature)$

temperature 是在大于零的参数,它控制着 softmax 的 soft 程度。温度越高,生成的分布越平滑;温度越低,生成的分布越接近离散的 one-hot 分布。训练中,可以通过逐渐降低温度,以逐步逼近真实的离散分布。

这样就得到了基于gumbel-max的采样过程:

  • 对于网络输出的一个K维向量v,生成K个服从均匀分布U(0,1)的独立样本 $\gamma_1,...,\gamma_K$;
  • 通过 $G_i$ 计算得到 $G_i$;
  • 对应相加得到新的值向量
  • 通过softmax函数计算概率大小得到最终的类别。

1
2
3
4
5
6
7
8

def differentiable_gumble_sample(logits, temperature=1):

noise = tf.random_uniform(tf.shape(logits), seed=11)

logits_with_noise = logits - tf.log(-tf.log(noise))

return tf.nn.softmax(logits_with_noise / temperature)

Gumbel分布

首先,我们介绍一样何为gumbel分布,gumbel分布是一种极值型分布。举例而言,假设一天内每次的喝水量为一个随机变量,它可能服从某个概率分布,记下这一天内喝的10次水的量并取最大的一个作为当天的喝水量值。显然,每天的喝水量值也是一个随机变量,并且它的概率分布即为 Gumbel 分布。实际上,只要是指数族分布,它的极值分布都服从Gumbel分布。

他的概率密度函数:

$f(s;\mu, \beta)=e^{-Z-e^{-Z}}, Z=\frac{X-\mu}{\beta}$

公式中,$\mu$ 是位置系数(Gumbel 分布的众数是 $\mu$),$\beta$ 是尺度系数(Gumbel 分布的方差是 $\frac{\pi^2}{6}\beta^2$)。

1
2
3
4
5
6

def gumbel_pdf(x, mu=0, beta=1):

z = (x - mu) / beta

return np.exp(-z - np.exp(-z)) / beta

为什么方法一与方法三生成一样的效果?

先定义一个多项分布,作出真实的概率密度图。再通过采样的方式比较各种方法的效果。这里定义了一个8类别的多项分布,其真实的密度函数如下左图。 首先我们直接根据真实的分布利用np.random.choice函数采样对比效果(实现代码放在文末) 左图为真实概率分布,右图为采用np.random.choice函数采样的结果(采样次数为1000)。可见效果还是非常好的,要是没有不能求梯度这个问题,直接从原分布采样是再好不过的。接着通过前述的方法添加Gumbel噪声采样,同时也添加正态分布和均匀分布的噪声作对比。(基于gumbel-max的采样) 可以明显看到Gumbel噪声的采样效果是最好的,正态分布其次,均匀分布最差。也就是说用Gumbel分布的样本点最接近真实分布的样本。 最后,我们基于gumbel-softmax做采样,左图设置temperature=0.1,经过softmax函数后得到的概率分布接近one-hot分布,用此概率分布对分类求期望值,得到结果为左图,可以较好地逼近方法一的采样结果;右图设置temperature=5,经过softmax函数后得到的概率分布接近均匀分布,再对分类求期望值,得到的结果集中在类别3、 4(中间的类别)。这和gumbel-softmax具备的性质是一致的,temperature控制着softmax的soft程度,温度越高,生成的分布越平滑(接近这里的均匀分布);温度越低,生成的分布越接近离散的one-hot分布。因此,训练时可以逐渐降低温度,以逐步逼近真实的离散分布。(基于gumbel-softmax的采样)

到此为此,我们也算用一组实验去解释了为什么方法二、方法三是可行的。具体的代码放在文末了,感兴趣的可以研究一下。

为什么使用Gumbel分布就可以逼近多项分布采样?

为什么它可以有这样的效果?为什么添加gumbel噪声就可以近似范畴分布(category distribution)采样。

我们来考虑一个问题,假设一共有K个类别,那么第k个类别恰好是最大的概率是多少?

对于一个K维的输出向量,每个维度的值记为,通过softmax函数可得,取到每个维度的概率为:

$\pi_k=\frac{e^{x_k}}{\sum_{i=1}^Ke^{x_i}}$

设 $x_k=log \alpha_k$,可以看出 $\alpha_k$即 $\pi_k$,这是直接用softmax得到的概率密度函数,它也可以换一种方式去说,对每个 $x_k$ 添加独立的标准Gumbel分布(尺度参数为1,位置参数为0)噪声,并选择值最大的维度作为输出,得到的概率密度同样为$\alpha_k$。

为什么再参数化(reparameterization tricks)就可以变得可导?

reparameterization tricks的思想是说如果我们能把一个复杂变量用一个标准变量来表示,比如 $Z=f(\gamma)$, 其中 $\gamma$ 服从N(0,1), 那么我们就可以用$\gamma$这个变量取代z。

这样做是有好处的,一方面在更新梯度时可以将随机变量提取出来,不影响对参数的更新;另一方面假如我们要依据 $p(z;\theta)$ 采样,然后再利用采样处的梯度修正p,这样两次的误差就会叠加,但现在只需要从一个分布非常稳定的random seed的分布中采样,比如N(0,1)所以noise小得多。

code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246

from scipy.optimize import curve_fit

import numpy as np

import matplotlib.pyplot as plt



n_cats = 8

n_samples = 1000

cats = np.arange(n_cats)

probs = np.random.randint(low=1, high=20, size=n_cats)

probs = probs / sum(probs)

logits = np.log(probs)



def plot_probs(): # 真实概率分布

plt.bar(cats, probs)

plt.xlabel("Category")

plt.ylabel("Original Probability")



def plot_estimated_probs(samples,ylabel=''):

n_cats = np.max(samples)+1

estd_probs,_,_ = plt.hist(samples,bins=np.arange(n_cats+1),align='left',edgecolor='white')

plt.xlabel('Category')

plt.ylabel(ylabel+'Estimated probability')

return estd_probs



def print_probs(probs):

print(probs)



samples = np.random.choice(cats,p=probs,size=n_samples) # 依概率采样



plt.figure()

plt.subplot(1,2,1)

plot_probs()

plt.subplot(1,2,2)

estd_probs = plot_estimated_probs(samples)

plt.tight_layout() # 紧凑显示图片

plt.savefig('/home/zhumingchao/PycharmProjects/matplot/gumbel1')



print('Original probabilities:\t',end='')

print_probs(probs)

print('Estimated probabilities:\t',end='')

print_probs(estd_probs)

plt.show()

######################################



def sample_gumbel(logits):

noise = np.random.gumbel(size=len(logits))

sample = np.argmax(logits+noise)

return sample

gumbel_samples = [sample_gumbel(logits) for _ in range(n_samples)]



def sample_uniform(logits):

noise = np.random.uniform(size=len(logits))

sample = np.argmax(logits+noise)

return sample

uniform_samples = [sample_uniform(logits) for _ in range(n_samples)]



def sample_normal(logits):

noise = np.random.normal(size=len(logits))

sample = np.argmax(logits+noise)

# print('old',sample)

return sample

normal_samples = [sample_normal(logits) for _ in range(n_samples)]



plt.figure(figsize=(10,4))

plt.subplot(1,4,1)

plot_probs()

plt.subplot(1,4,2)

gumbel_estd_probs = plot_estimated_probs(gumbel_samples,'Gumbel ')

plt.subplot(1,4,3)

normal_estd_probs = plot_estimated_probs(normal_samples,'Normal ')

plt.subplot(1,4,4)

uniform_estd_probs = plot_estimated_probs(uniform_samples,'Uniform ')

plt.tight_layout()

plt.savefig('/home/zhumingchao/PycharmProjects/matplot/gumbel2')



print('Original probabilities:\t',end='')

print_probs(probs)

print('Gumbel Estimated probabilities:\t',end='')

print_probs(gumbel_estd_probs)

print('Normal Estimated probabilities:\t',end='')

print_probs(normal_estd_probs)

print('Uniform Estimated probabilities:\t',end='')

print_probs(uniform_estd_probs)

plt.show()

#######################################



def softmax(logits):

return np.exp(logits)/np.sum(np.exp(logits))



def differentiable_sample_1(logits, cats_range, temperature=.1):

noise = np.random.gumbel(size=len(logits))

logits_with_noise = softmax((logits+noise)/temperature)

# print(logits_with_noise)

sample = np.sum(logits_with_noise*cats_range)

return sample

differentiable_samples_1 = [differentiable_sample_1(logits,np.arange(n_cats)) for _ in range(n_samples)]



def differentiable_sample_2(logits, cats_range, temperature=5):

noise = np.random.gumbel(size=len(logits))

logits_with_noise = softmax((logits+noise)/temperature)

# print(logits_with_noise)

sample = np.sum(logits_with_noise*cats_range)

return sample

differentiable_samples_2 = [differentiable_sample_2(logits,np.arange(n_cats)) for _ in range(n_samples)]



def plot_estimated_probs_(samples,ylabel=''):

samples = np.rint(samples)

n_cats = np.max(samples)+1

estd_probs,_,_ = plt.hist(samples,bins=np.arange(n_cats+1),align='left',edgecolor='white')

plt.xlabel('Category')

plt.ylabel(ylabel+'Estimated probability')

return estd_probs



plt.figure(figsize=(8,4))

plt.subplot(1,2,1)

gumbelsoft_estd_probs_1 = plot_estimated_probs_(differentiable_samples_1,'Gumbel softmax')

plt.subplot(1,2,2)

gumbelsoft_estd_probs_2 = plot_estimated_probs_(differentiable_samples_2,'Gumbel softmax')

plt.tight_layout()

plt.savefig('/home/zhumingchao/PycharmProjects/matplot/gumbel3')



print('Gumbel Softmax Estimated probabilities:\t',end='')

print_probs(gumbelsoft_estd_probs_1)

plt.show()

来自 https://blog.csdn.net/weixin_40255337/article/details/83303702