(input_tensor.shape))
if ndims == 2: 维度恰好2,直接返回
return input_tensor
width = input_tensor.shape[-1] 获取最后一维(倒数第一维)的大小
output_tensor = tf.reshape(input_tensor, [-1, width]) 最后一维不变,前面的其他维度自适应(相乘)
return output_tensor
■二维变多维
def reshape_from_matrix(output_tensor, orig_shape_list):
"""Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
if len(orig_shape_list) == 2:
return output_tensor
output_shape = get_shape_list(output_tensor)
orig_dims = orig_shape_list[0:-1] 将原始形状 去除最后一维
width = output_shape[-1] 宽度为最后一维的大小
return tf.reshape(output_tensor, orig_dims + [width])
■秩的断言
def assert_rank(tensor, expected_rank, name=None):
"""Raises an exception if the tensor rank is not of the expected rank.
如果对不上,就报错
Args:参数:张量,期望的秩,名称(用于打印报错信息)
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected shape doesn't match the actual shape.
"""
if name is None: 如果没有指定名称,则取张量的变量名称
name = tensor.name
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types): 如果指定的秩是个整数
expected_rank_dict[expected_rank] = True
else:
for x in expected_rank:如果指定的秩是一个列表(多个秩)【BERT模型源码解析】
推荐阅读
- 【lwip】11-UDP协议&源码分析
- 硬核剖析Java锁底层AQS源码,深入理解底层架构设计
- SpringCloudAlibaba 微服务组件 Nacos 之配置中心源码深度解析
- Seata 1.5.2 源码学习
- MindStudio模型训练场景精度比对全流程和结果分析
- .NET 源码学习 [数据结构-线性表1.2] 链表与 LinkedList<T>
- Redisson源码解读-公平锁
- OpenHarmony移植案例: build lite源码分析之hb命令__entry__.py
- 【深入浅出 Yarn 架构与实现】1-2 搭建 Hadoop 源码阅读环境
- JVM学习笔记——内存模型篇