首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.contrib.seq2seq.gather_tree是如何工作的?

tf.contrib.seq2seq.gather_tree是如何工作的?
EN

Stack Overflow用户
提问于 2018-01-03 20:49:14
回答 1查看 325关注 0票数 1

contrib.seq2seq中的gather_tree究竟是如何工作的?我可以看到它采用了预测的ids和梁的父ids,并以某种方式返回了最终的梁,但实际上在引擎盖下面发生了什么?似乎没有任何Python代码库可供我检查以找出答案。API不是很有说明性;

有没有tf.contrib.seq2seq.gather_tree的代码源?我使用的是TensorFlow 1.3,但深入了解gen_beam_search_ops.py似乎没有什么帮助。

EN

回答 1

Stack Overflow用户

发布于 2018-03-20 17:56:07

具体代码如下:

代码语言:javascript
复制
def gather_tree_py(values, parents):
  """Gathers path through a tree backwards from the leave nodes. Used
  to reconstruct beams given their parents."""

  beam_length = values.shape[0]
  num_beams = values.shape[1]
  res = np.zeros_like(values)
  res[-1, :] = values[-1, :]
  for beam_id in range(num_beams):
    parent = parents[-1][beam_id]
    for level in reversed(range(beam_length - 1)):
      res[level, beam_id] = values[level][parent]
      parent = parents[level][parent]
  return np.array(res).astype(values.dtype)


def gather_tree(values, parents):
  """Tensor version of gather_tree_py"""

  res = tf.py_func(
      func=gather_tree_py, inp=[values, parents], Tout=values.dtype)
  res.set_shape(values.get_shape().as_list())
  return res

github: seq2seq beam_search

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48077768

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档