torch.tensor参数(torch.nn.linear函数的参数)

作者:电脑培训网 2024-05-06 22:25:39 343

[torch.nn.Parameter]参数介绍及使用

文章目录

torch.nn.Parameter基本介绍参数构造参数访问参数初始化使用内置初始化自定义初始化参数绑定参考torch.nn.Parameter

torch.tensor参数(torch.nn.linear函数的参数)

基本介绍

torch.nn.Parameter它是继承自torch.Tensor的子类,主要功能是作为nn.Module中的可训练参数。它与torch.Tensor的区别在于,nn.Parameter会自动被认为是模块的可训练参数,即添加到parameter()迭代器中。

具体格式如下:

torch.nn.parameter.Parameter

其中data为要传入的Tensor,requires_grad默认为True。

事实上,torch.nn中提供的模块中的参数都是nn.Parameter类,例如:

module=nn.Linear(3,3)type(module.weight)#torch.nn.parameter.Parametertype(module.bias)#torch.nn.parameter.Parameter

参数构造

nn.Parameter可以看做是一个类型转换函数,它将一个不可训练类型的Tensor转换为可训练类型的参数,并将这个参数绑定到这个模块上。nn.Parameter()添加的参数将被添加到Parameters列表中,它会被发送到优化器随着训练一起学习和更新。

此时调用parameters()方法就会显示参数。读者可以自行体验以下两端代码:

'''代码片段1'''classNet(nn.Module):def__init__(self):super().__init__()self.weight=torch.randn(3,3)self.bias=torch.randn(3)defforward(self,inputs):passnet=Net()print(list(net.parameters()))#[]'''代码片段2'''classNet(nn.Module):def__init__(self):super().__init__()self.weight=**nn.Parameter**(torch.randn(3,3))#将张量转换为参数类型self.bias=**nn.Parameter**(torch.randn(3))defforward(self,inputs):passnet=Net()print(list(**net.parameters()**))#显示参数#[参数包含:#tensor([[-0.4584,0.3815,-0.4522],#[2.1236,0.7928,-0.7095],#[-1.4921,-0.5689,-0.2342]],requires_grad=True),参数包含:#tensor([-0.6971,-0.7651,0.7897],requires_grad=True)]

nn.Parameter相当于将传入的数据包装成一个参数。如果你想直接访问/使用数据而不是参数本身,你可以调用nn.Parameter对象上的data属性:

a=torch.tensor([1,2,3]).to(torch.float32)param=nn.Parameter(a)print(param)#参数包含:#tensor([1.2.3.],require_grad=True)print(param.data)#张量([1.2.3.])

参数访问

nn.Module中有一个**state_dict()**方法,它将以字典的形式返回模块的所有状态,包括模块的参数和持久缓冲区。字典的键是相应参数/缓冲区的名称。

由于所有模块都继承了nn.Module,因此我们可以在任何模块上调用state_dict()方法来查看状态:

Linear_layer=nn.Linear(2,2)print(linear_layer.state_dict())#OrderedDict([('权重',张量([[0.2602,-0.2318],#[-0.5192,0.0130]])),('偏差',张量([0.5890,0.2476]))])print(linear_layer.state_dict().keys())#odict_keys(['权重','偏差'])

对于线性层,除了state_dict()之外,我们还可以直接调用对应的属性,如下:

Linear_layer=nn.Linear(2,1)print(linear_layer.weight)#参数包含:#tensor([[-0.1990,0.3394]],requires_grad=True)print(linear_layer.bias)#参数包含:#tensor([0.2697],requires_grad=True)

需要注意的是,上面返回的都是参数对象。如果需要使用数据,可以调用data

属性。

参数初始化

使用内置初始化

对于下面的单隐藏层网络,我们希望将内置初始化器应用于两个线性层

类Net(nn.Module):def__init__(self):super().__init__()self.layers=nn.Sequential(nn.Linear(3,2),nn.ReLU(),nn.Linear(2,3),)defforward(self,X):returnself.layers(X)

假设权重从N(0,1)中采样,并且bias全部初始化为0,初始化代码如下:

definit_normal(module):#需要判断子模块是否属于nn.Linear类,因为激活函数没有参数iftype(module)==nn.Linear:nn.init.normal_(module.weight,平均值=0,标准差=1)nn.init.zeros_(module.bias)

net=Net()net.apply(init_normal)forparaminnet.parameters():print(param)#参数包含:#tensor([[-0.3560,0.8078,-2.4084],#[0.1700,-0.3217,-1.3320]],requires_grad=True)#参数包含:#tensor([0.0.],requires_grad=True)#参数包含:#tensor([[-0.8025,-1.0695],#[-1.7031,-0.3068],#[-0.3499,0.4263]],requires_grad=True)#参数包含:#tensor([0.0.0.],requires_grad=True)

调用net上的apply方法会递归地将init_normal函数应用到其下的所有子模块。

自定义初始化

如果我们想要自定义初始化,例如使用下面的分布来初始化网络的权重:

defmy_init:如果type==nn.linear:nn.init.uniform_mask=module.weight.data.abs=面具

net=Net()net.apply(my_init)forparaminnet.parameters():print(param)#参数包含:#tensor([[-0.0000,-5.9610,8.0000],#[-0.0000,-0.0000,7.6041],quirenes_grad=true)commatetaing:tensorcommatecomenterage:tensor)

如果我们希望第二个隐藏层和第三个隐藏层共享参数,我们可以做到这一点:

共享=nn.Linear(8,8)net=nn.Sequential(nn.Linear(4,8),nn.ReLU(),共享,nn.ReLU(),共享,nn.ReLU(),nn.Linear(8,1))

参考

Pytorch研究注释-顺序类,参数管理和GPU_LAREGES的博客CSDNBlog_sequential类

TORCH.NN中国文档

Python的torch.nn.参数初始化方法_郝大侠的博客-CSDN博客_torch.nn.参数初始化

相关推荐

  • 培训机构学习linux,培训班linux

    培训机构学习linux,培训班linux

    大家好,今天小编关注到一个比较有意思的话题,就是关于培训机构学习linux的问题,于是小编就整理了3个相关介绍培训机构学习linux的解答,让我们一起看看吧。l…

    培训机构学习linux,培训班linux 2024-10-01 03:26:19
  • 云计算培训linux机构,linux云计算培训班价格

    云计算培训linux机构,linux云计算培训班价格

    大家好,今天小编关注到一个比较有意思的话题,就是关于云计算培训linux机构的问题,于是小编就整理了3个相关介绍云计算培训linux机构的解答,让我们一起看看吧…

    云计算培训linux机构,linux云计算培训班价格 2024-06-18 20:06:54
  • 以色列印度培训机构,以色列christina培训

    以色列印度培训机构,以色列christina培训

    大家好,今天小编关注到一个比较有意思的话题,就是关于以色列印度培训机构的问题,于是小编就整理了4个相关介绍以色列印度培训机构的解答,让我们一起看看吧。亚训有哪些…

    以色列印度培训机构,以色列christina培训 2024-06-07 07:44:05
  • 机智云 esp8266 arduino(esp8266接入机智云教程)

    机智云 esp8266 arduino(esp8266接入机智云教程)

    第十章STM32+ESP8266连接机智云,实现小型物联网智能家居项目前言最近有很多朋友私信,要求我推出一个关于远程控制以及通过APP获取传感器信息的实验教程。…

    机智云 esp8266 arduino(esp8266接入机智云教程) 2024-05-07 09:36:57
  • type = module(typedefintstatus)

    type = module(typedefintstatus)

    type="module"你知道,但是type="importmap"你知道有新系列:Vue2和Vue3技能手册如果你有梦想,有实用信息,就微信搜索【大世界运动…

    type = module(typedefintstatus) 2024-05-07 08:48:12
热门推荐

猜你喜欢