我正在学习一个文本生成示例https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/generatetext/GenerateTxtCharCompGraphModel.java。lstm网络的输出是概率分布,据我所知,这是一个双数组,其中每个值都显示与数组中索引对应的字符的概率。因此,我无法理解从发行版中获取字符索引的以下代码:
/** Given a probability distribution over discrete classes, sample from the distribution
* and return the generated class index.
* @param distribution Probability distribution over classes. Must sum to 1.0
*/
static int sampleFromDistribution(double[] distribution, Random rng){
double d = 0.0;
double sum = 0.0;
for( int t=0; t<10; t++ ) {
d = rng.nextDouble();
sum = 0.0;
for( int i=0; i<distribution.length; i++ ){
sum += distribution[i];
if( d <= sum ) return i;
}
//If we haven't found the right index yet, maybe the sum is slightly
//lower than 1 due to rounding error, so try again.
}
//Should be extremely unlikely to happen if distribution is a valid probability distribution
throw new IllegalArgumentException("Distribution is invalid? d="+d+", sum="+sum);
}我们似乎得到了一个随机值。为什么我们不选择值最高的索引呢?如果我想选择的不是一个,而是两个或三个最有可能的下一个字符,我该怎么办?
发布于 2021-09-24 12:58:05
此函数从发行版中抽取样本,而不是简单地返回最可能出现的字符类。
这也意味着您没有得到最有可能的字符,相反,您得到的是一个随机字符,其概率是给定的概率分布定义的。
首先从统一分布(rng.nextDouble())中获得0到1之间的随机值,然后找到该值在给定分布中的位置。
你可以想象它是这样的(如果你的字母表中只有a到f):
[ a | b | c | d | e | f ]
0.0 0.3 0.5 1.0如果所绘制的随机值刚刚超过0.5,它将生成一个e,如果它只是小于该值,它将是一个d。
每个字母根据其在分布中的权重,在0到1之间占据这一行的比例空间。
https://stackoverflow.com/questions/69281314
复制相似问题