Softmax的核心思想
softmax ,说人话,就是从一堆数字中,找出一个最大值,就是找 max。
1.概率归一化,而非极端化
softmax的主要目的是将任意实数值转换为有效的概率分布,确保:
- 所有输出值都在 [0,1] 区间内
- 所有输出值的和等于1
- 保持原始值的相对大小关系
2.放大差异,但不是"无限接近"
让我用一个简单的例子来说明:
import torch
import torch.nn.functional as F
# 示例1:差异较小的logits
logits1 = torch.tensor([2.0, 2.1, 1.9])
probs1 = F.softmax(logits1, dim=0)
print("小差异logits:", logits1.tolist())
print("softmax输出:", probs1.tolist())
print()
# 示例2:差异较大的logits
logits2 = torch.tensor([5.0, 1.0, 1.0])
probs2 = F.softmax(logits2, dim=0)
print("大差异logits:", logits2.tolist())
print("softmax输出:", probs2.tolist())
print()
# 示例3:差异极大的logits
logits3 = torch.tensor([10.0, 1.0, 1.0])
probs3 = F.softmax(logits3, dim=0)
print("极大差异logits:", logits3.tolist())
print("softmax输出:", probs3.tolist())
运行结果大致为:
小差异logits: [2.0, 2.1, 1.9]
softmax输出: [0.307, 0.410, 0.283] # 比较均匀
大差异logits: [5.0, 1.0, 1.0]
softmax输出: [0.952, 0.024, 0.024] # 明显倾向最大值
极大差异logits: [10.0, 1.0, 1.0]
softmax输出: [0.9999, 0.00005, 0.00005] # 几乎全部概率给最大值
3.真正的核心思想
a)指数放大效应
# softmax公式: softmax(x_i) = exp(x_i) / Σexp(x_j)
- exp() 函数会指数级地放大原始差异
- 原始值差1,exp后差约2.7倍
- 原始值差2,exp后差约7.4倍
b)保持相对顺序
- 如果 x > x,那么 softmax(x) > softmax(x)
- 不会改变原始值的大小顺序
c)"软"选择而非"硬"选择
# 硬选择(one-hot):[0, 0, 1, 0]
# 软选择(softmax):[0.1, 0.2, 0.6, 0.1]
4.为什么叫"soft"max?
与传统的"hardmax"(argmax)相比:
- Hardmax: 只选出最大值,其他全为0 → [0, 0, 1, 0]
- Softmax: 给最大值最高概率,但其他值也保留小概率 → [0.1, 0.2, 0.6, 0.1]
5.实际应用中的意义
a)训练阶段
- 保留所有可能性,让模型学习细微差异
- 梯度可以流向所有参数,不会出现梯度为0的情况
b)推理阶段
- 提供每个类别的置信度
- 可以根据概率分布做更灵活的决策
6.总结
softmax的核心思想是:
- 概率化:将任意实数转为有效概率分布
- 差异放大:通过指数函数放大原始差异
- 软决策:保持所有选项的可能性,而非绝对的0或1
- 可导性:保证整个过程可微分,适合梯度下降
所以您的理解部分正确——它确实会增加区分度,让最大的值获得更高概率,但关键在于它是"软"的选择,不是绝对的0和1,这种"软性"正是它在深度学习中如此重要的原因。
import torch
import torch.nn.functional as F
# 示例1:差异较小的logits
logits1 = torch.tensor([2.0, 2.1, 1.9])
probs1 = F.softmax(logits1, dim=0)
print("小差异logits:", logits1.tolist())
print("softmax输出:", probs1.tolist())
print()
# 示例2:差异较大的logits
logits2 = torch.tensor([5.0, 1.0, 1.0])
probs2 = F.softmax(logits2, dim=0)
print("大差异logits:", logits2.tolist())
print("softmax输出:", probs2.tolist())
print()
# 示例3:差异极大的logits
logits3 = torch.tensor([10.0, 1.0, 1.0])
probs3 = F.softmax(logits3, dim=0)
print("极大差异logits:", logits3.tolist())
print("softmax输出:", probs3.tolist())
小差异logits: [2.0, 2.1, 1.9]
softmax输出: [0.307, 0.410, 0.283] # 比较均匀
大差异logits: [5.0, 1.0, 1.0]
softmax输出: [0.952, 0.024, 0.024] # 明显倾向最大值
极大差异logits: [10.0, 1.0, 1.0]
softmax输出: [0.9999, 0.00005, 0.00005] # 几乎全部概率给最大值
# softmax公式: softmax(x_i) = exp(x_i) / Σexp(x_j)
# 硬选择(one-hot):[0, 0, 1, 0]
# 软选择(softmax):[0.1, 0.2, 0.6, 0.1]