TensorFlow-静态图和PyTorch-动态图区别 | StriveZs的博客

TensorFlow-静态图和PyTorch-动态图区别

TensorFlow-静态图和PyTorch-动态图区别

最近在重新学习一遍pytorch,之前对于自动求导中的计算图的概念不是很清楚,这里从头看了一遍,有了解一下,简单的写一下自己的笔记。

PyTorch自动求导看起来非常像TensorFlow,这两个框架中,我们都定义了计算图,使用自动微分来计算梯度,但是两者之间最大的不同是TensorFlow的计算图是静态的,而PyTorch使用的是动态的计算图。
在TensorFlow中,我们定义计算图一次,然后后续就会重复执行这个相同的图,后面的话可能只是会提供不同的输入数据,而在PyTorch中,每一个前向通道(forward)定义一个新的计算图。

静态图的好处在于你可以预先对图进行优化。例如:一个框架可能要融合一些图的运算来提升效率,或者产生一个策略来将图分布到多个GPU或者机器上,如果重复使用相同的图,那么再重复运行一个图时,前期潜在的代价高昂的预先优化的消耗就会被分摊开。

静态图和动态图的一个区别是控制流。对于一些模型,我们希望对每个数据点执行不同的计算。例如:一个递归神经网络可能对每个数据点执行不同的时间步数,这个展开(unrolling)可以作为一个循环来实现。
对于一个静态图,循环结构要作为图的一部分,因此TensorFlow提供了运算符来把循环嵌入到图当中。对于动态图来说,情况更加简单,既然我们为每个例子即时创建计算图,我们可以使用普通的命令式控制流来为每个输入执行不同的计算。

tensorflow的forward只会根据第一次模型前向传播来构建一个静态的计算图, 后面的梯度自动求导都是根据这个计算图来计算的, 但是pytorch则不是, 它会为每次forward计算都构建一个动态图的计算图, 后续的每一次迭代都是使用一个新的计算图进行计算的.

PyTorch应用控制流的动态图实例

作为动态图(网络结构发生变化并不影响计算图计算梯度)和权重共享的一个例子,我们实现了一个非常奇怪的模型:一个全连接的ReLU网络,在每一次前向传播时,它的隐藏层的层数为随机1到4之间的数,这样可以多次重用相同的权重来计算。

因为这个模型可以使用普通的Python流控制来实现循环,并且我们可以通过定义转发时多次重用同一个模块来实现最内层的权重共享。

下面是例子的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
import random

# fixme: 定义网络
class DynamicNet(nn.Module):
def __init__(self, D_in, H, D_out):
"""
构造函数,在这里需要将网络的各个模块进行实例化, 并把他们作为成员变量
:params D_in: 输入维度
:params H: 隐藏层维度
:params D_out: 输出维度
"""
super(DynamicNet, self).__init__()

self.input_layer = nn.Linear(D_in, H)
self.hidden_layer = nn.Linear(H, H)
self.output_layer = nn.Linear(H, D_out)

self.relu = nn.ReLU()

def forward(self, x):
x1 = self.input_layer(x)
t_relu = self.relu(x1)

# 定义0-3个隐藏层,利用pytorch动态图的特征,这种做法是可行的
## 重复调用self.hidden_layer 0-3次,由于pytorch是采用动态图的,因此每一次forward都会创建一个新的动态图,不影响梯度计算
for _ in range(random.randint(0, 3)):
t_relu = self.relu(self.hidden_layer(t_relu))

pred = self.output_layer(t_relu)

return pred

# fixme: 参数配置
dtype = torch.float
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
N, D_in, H, D_out = 64, 1000, 100, 10 # N为批量大小,D_in是输入维度,H是隐藏层维度,D_out是输出层维度
learning_rate = 1e-4
epochs = 100

# fixeme: 创建输入和输出随机张量
input = torch.randn(N, D_in)
label = torch.randn(N, D_out)

# fixme: 实例化模型
model = DynamicNet(D_in, H, D_out)

# fixme: 损失函数的定义
loss_fn = torch.nn.MSELoss(reduction='sum')

# fixme: 使用torch.optim定义参数优化器
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate) # 这里用的是Adam

# fixme: 训练
for epoch in range(epochs):
# 前向过程
pred = model(input)

# 计算loss
loss = loss_fn(pred, label)
print('当前代数:{},当前loss为:{}'.format(epoch,loss.item()))

# 模型参数梯度置零
optimizer.zero_grad()

# loss反向传播
loss.backward()

# 更新权重
optimizer.step()
当前代数:0,当前loss为:702.3717651367188
当前代数:1,当前loss为:700.1735229492188
当前代数:2,当前loss为:697.996826171875
当前代数:3,当前loss为:701.37646484375
当前代数:4,当前loss为:745.0324096679688
当前代数:5,当前loss为:686.2260131835938
当前代数:6,当前loss为:729.3468017578125
当前代数:7,当前loss为:719.3234252929688
当前代数:8,当前loss为:700.3496704101562
当前代数:9,当前loss为:697.38720703125
当前代数:10,当前loss为:676.840576171875
当前代数:11,当前loss为:674.8969116210938
当前代数:12,当前loss为:667.5950927734375
当前代数:13,当前loss为:657.723388671875
当前代数:14,当前loss为:668.5684814453125
当前代数:15,当前loss为:691.20361328125
当前代数:16,当前loss为:629.1569213867188
当前代数:17,当前loss为:662.6659545898438
当前代数:18,当前loss为:611.5732421875
当前代数:19,当前loss为:602.4924926757812
当前代数:20,当前loss为:698.65625
当前代数:21,当前loss为:583.9902954101562
当前代数:22,当前loss为:655.1369018554688
当前代数:23,当前loss为:566.2965087890625
当前代数:24,当前loss为:557.3516845703125
当前代数:25,当前loss为:547.8927001953125
当前代数:26,当前loss为:650.374755859375
当前代数:27,当前loss为:648.9633178710938
当前代数:28,当前loss为:697.7493286132812
当前代数:29,当前loss为:515.4136962890625
当前代数:30,当前loss为:507.982421875
当前代数:31,当前loss为:500.00030517578125
当前代数:32,当前loss为:491.6020812988281
当前代数:33,当前loss为:482.8730163574219
当前代数:34,当前loss为:697.01220703125
当前代数:35,当前loss为:687.0211791992188
当前代数:36,当前loss为:638.8866577148438
当前代数:37,当前loss为:696.5703735351562
当前代数:38,当前loss为:636.4703369140625
当前代数:39,当前loss为:685.7159423828125
当前代数:40,当前loss为:685.2298583984375
当前代数:41,当前loss为:684.5957641601562
当前代数:42,当前loss为:695.6959228515625
当前代数:43,当前loss为:427.000732421875
当前代数:44,当前loss为:682.4891967773438
当前代数:45,当前loss为:419.1412048339844
当前代数:46,当前loss为:681.0149536132812
当前代数:47,当前loss为:694.716552734375
当前代数:48,当前loss为:406.95831298828125
当前代数:49,当前loss为:678.7310180664062
当前代数:50,当前loss为:623.3473510742188
当前代数:51,当前loss为:395.0804443359375
当前代数:52,当前loss为:693.6585083007812
当前代数:53,当前loss为:675.7344360351562
当前代数:54,当前loss为:618.6995239257812
当前代数:55,当前loss为:616.9414672851562
当前代数:56,当前loss为:614.5744018554688
当前代数:57,当前loss为:692.4542236328125
当前代数:58,当前loss为:609.0692138671875
当前代数:59,当前loss为:691.904052734375
当前代数:60,当前loss为:603.1943359375
当前代数:61,当前loss为:669.9990234375
当前代数:62,当前loss为:691.0020141601562
当前代数:63,当前loss为:594.264892578125
当前代数:64,当前loss为:591.1102294921875
当前代数:65,当前loss为:666.8950805664062
当前代数:66,当前loss为:665.9771728515625
当前代数:67,当前loss为:689.326171875
当前代数:68,当前loss为:688.93701171875
当前代数:69,当前loss为:688.5073852539062
当前代数:70,当前loss为:574.0647583007812
当前代数:71,当前loss为:661.115234375
当前代数:72,当前loss为:660.0462036132812
当前代数:73,当前loss为:566.465576171875
当前代数:74,当前loss为:360.7765197753906
当前代数:75,当前loss为:685.8043212890625
当前代数:76,当前loss为:357.83026123046875
当前代数:77,当前loss为:556.7740478515625
当前代数:78,当前loss为:684.478515625
当前代数:79,当前loss为:683.9954223632812
当前代数:80,当前loss为:683.468994140625
当前代数:81,当前loss为:650.9133911132812
当前代数:82,当前loss为:545.8726806640625
当前代数:83,当前loss为:543.443359375
当前代数:84,当前loss为:540.5396118164062
当前代数:85,当前loss为:342.5977478027344
当前代数:86,当前loss为:645.627685546875
当前代数:87,当前loss为:644.410888671875
当前代数:88,当前loss为:642.9930419921875
当前代数:89,当前loss为:641.402099609375
当前代数:90,当前loss为:524.3102416992188
当前代数:91,当前loss为:521.52880859375
当前代数:92,当前loss为:518.2946166992188
当前代数:93,当前loss为:676.001220703125
当前代数:94,当前loss为:633.2501220703125
当前代数:95,当前loss为:674.614501953125
当前代数:96,当前loss为:673.8367919921875
当前代数:97,当前loss为:628.274169921875
当前代数:98,当前loss为:626.4828491210938
当前代数:99,当前loss为:325.72052001953125
StriveZs wechat
Hobby lead  creation, technology change world.