tensor比较大小函数
参考:https://blog.csdn.net/qq_50001789/article/details/128973672 (opens new window)
# torch.ge、torch.gt、torch.le、torch.lt、torch.ne、torch.eq
torch.ge(input, other, *, out=None) → Tensor
torch.gt(input, other, *, out=None) → Tensor
torch.le(input, other, *, out=None) → Tensor
torch.lt(input, other, *, out=None) → Tensor
torch.ne(input, other, *, out=None) → Tensor
torch.eq(input, other, *, out=None) → Tensor
1
2
3
4
5
6
2
3
4
5
6
功能:
torch.ge:实现大于等于(≥)运算
torch.gt:实现大于(>)运算
torch.le:实现小于等于(≤)运算
torch.lt:实现小于(<)运算
torch.ne:实现不等于(≠)运算 输入:
input
:待比较的数组other
:比较数值,可以是数组,也可以是一个数。tensor
或float
格式
输出:
布尔张量,尺寸和input
相同,当input
和other
元素之间符合运算时,对应位置元素为True
,否则为Flase
。
注:
- 第二个参数可以是一个数字,也可以是一个张量数组,只要与第一个参数满足广播条件即可;
- 也可以通过tensor加后缀的形式实现,如a.ge,a相当于input,即待比较的数组;
- 如果输入的是数组,则必须是tensor类型
代码案例:
import torch
a=torch.arange(5)
b=torch.tensor(3)
print(a)
print(b)
print(torch.ge(a,b))
print(torch.gt(a,b))
print(torch.le(a,b))
print(torch.lt(a,b))
print(torch.ne(a,b))
print(torch.eq(a,b))
1
2
3
4
5
6
7
8
9
10
11
2
3
4
5
6
7
8
9
10
11
输出
tensor([0, 1, 2, 3, 4])
tensor(3)
# 大于等于
tensor([False, False, False, True, True])
# 大于
tensor([False, False, False, False, True])
# 小于等于
tensor([ True, True, True, True, False])
# 小于
tensor([ True, True, True, False, False])
# 不等于
tensor([ True, True, True, False, True])
# 等于
tensor([False, False, False, True, False])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
2
3
4
5
6
7
8
9
10
11
12
13
14
官方文档
编辑 (opens new window)
上次更新: 2024/05/30, 07:49:34