前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Python|神经网络框架torch包[1]

Python|神经网络框架torch包[1]

作者头像
福贵
发布2020-05-20 23:05:10
2K0
发布2020-05-20 23:05:10
举报
文章被收录于专栏:菜鸟致敬菜鸟致敬

torch包主要是用于多维张量的数据结构和用于张量的数学操作。除此之外,还提供了许多用于张量有效序列化和任意类型的工具,还有一些其他相关的工具。

torch还有一个cuda的版本,如果NVIDIA的算力>=3.0,就可以使用。

张量

  1. torch.is_tensor(obj)->bool 返回对象是否为pytorch的张量。输入参数obj是需要判断的对象。
  2. torch.is_storage(obj)->bool 返回对象是否为pytorch的存储对象。输入参数obj是需要判断的对象。
  3. torch.is_comlplex(input)->bool 返回输入数是否为复数,pytorch的复数类型有torch.complex64和torch.complex128。输入参数input是pytorch的张量。
  4. torch.isfloatingpoint(input)->bool 返回输入数是否为浮点数,pytorch的浮点数类型有torch.float64,torch.float32和torch.float16。输入参数input是pytorch的张量。
  5. torch.setdefaultdtypr(d) 把默认的浮点数类型设置为d,torch.tensor()默认使用d类型。在一开始,d的类型是torch.float32。d是torch.dtype中的值。

例子:

代码语言:python
复制
>>> torch.tensor([1.2, 3]).dtype           # 初始化默认类型为float32
>>> torch.set_default_dtype(torch.float64)
>>> torch.tensor([1.2, 3]).dtype           # 默认新的浮点类型float64
torch.float64
  1. torch.getdefaultdtype()->torch.dtype 返回当前的默认浮点类型。

例子:

代码语言:python
复制
>>> torch.get_default_dtype()  # 初始默认浮点类型float32
torch.float32
>>> torch.set_default_dtype(torch.float64)
>>> torch.get_default_dtype()  # 现在默认改为float64
torch.float64
>>> torch.set_default_tensor_type(torch.FloatTensor)  # 设置张量类型也会影响torch.FloatTensor
>>> torch.get_default_dtype()  # 现在默认float32
torch.FloatTensor
torch.float32
  1. torch.setdefaulttensor_type(t) 输入参数t为张量数据的浮点类型。用于设置张量数据的默认类型。
  2. torch.numel(input)->int,输入input为张量形式。返回int,为元素的总数。例子:
代码语言:javascript
复制
>>> a = torch.randn(1, 2, 3, 4, 5)
>>> torch.numel(a)
120
>>> a = torch.zeros(4,4)
>>> torch.numel(a)
16
  1. torch.setprintoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, scimode=None) 设置打印的参数。

参数:

  • precision:浮点数输出的精度,默认4
  • threshold:超过数目则只打印数组的总览,默认1000
  • edgeitems:总览打印时每个维度开始与结束时打印的数目,默认3
  • linewidth:行宽,指一行打印的数目,默认80,当指打印总览的时候无效
  • profile:默认Sane,格式输出的样式
  • sci_mode:是否启用科学计数法,bool类型。默认None
  1. torch.setflushdenormal(mode)->bool 禁用在cpu进行异常的浮点运算,如果系统支持冲洗异常数,返回True。setflushdenormal()仅在x86架构支持SSE3。输入mode参数为bool。例子:
代码语言:python
复制
>>> torch.set_flush_denormal(True)
True
>>> torch.tensor([1e-323], dtype=torch.float64)
tensor([ 0.], dtype=torch.float64)
>>> torch.set_flush_denormal(False)
True
>>> torch.tensor([1e-323], dtype=torch.float64)
tensor(9.88131e-324 * 
      [ 1.0000], dtype=torch.float64)
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-05-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python与MySQL 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 张量
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档