Pytorch中retain_graph的坑如何解决
本篇内容主要讲解“Pytorch中retain_graph的坑如何解决”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Pytorch中retain_graph的坑如何解决”吧!
Pytorch中retain_graph的坑
在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是
在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;
############################
# (1) Update D network: maximize D(x)-1-D(G(z))
###########################
real_img = Variable(target)
if torch.cuda.is_available():
real_img = real_img.cuda()
z = Variable(data)
if torch.cuda.is_available():
z = z.cuda()
fake_img = netG(z)
netD.zero_grad()
real_out = netD(real_img).mean()
fake_out = netD(fake_img).mean()
d_loss = 1 - real_out + fake_out
d_loss.backward(retain_graph=True) #####
optimizerD.step()
############################
# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
###########################
netG.zero_grad()
g_loss = generator_criterion(fake_out, fake_img, real_img)
g_loss.backward()
optimizerG.step()
fake_img = netG(z)
fake_out = netD(fake_img).mean()
g_loss = generator_criterion(fake_out, fake_img, real_img)
running_results['g_loss'] += g_loss.data[0] * batch_size
d_loss = 1 - real_out + fake_out
running_results['d_loss'] += d_loss.data[0] * batch_size
running_results['d_score'] += real_out.data[0] * batch_size
running_results['g_score'] += fake_out.data[0] * batch_size
也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True) 让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True)就可以单独的计算梯度,屡试不爽。
但是另外一个问题在于,如果你都这么用的话,显存会爆炸,因为他保留了梯度,所以都没有及时释放掉,浪费资源。
而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!
Pytorch中有多次backward时需要retain_graph参数
Pytorch中的机制是每次调用loss.backward()时都会free掉计算图中所有缓存的buffers,当模型中可能有多次backward()时,因为前一次调用backward()时已经释放掉了buffer,所以下一次调用时会因为buffers不存在而报错
解决办法
loss.backward(retain_graph=True)
错误使用
-
清空过往梯度;optimizer.zero_grad()
-
反向传播,计算当前梯度;loss1.backward(retain_graph=True)
-
反向传播,计算当前梯度;loss2.backward(retain_graph=True)
-
根据梯度更新网络参数optimizer.step()
因为每次调用bckward时都没有将buffers释放掉,所以会导致内存溢出,迭代越来越慢(因为梯度都保存了,没有free)
正确使用
-
清空过往梯度;optimizer.zero_grad()
-
反向传播,计算当前梯度;loss1.backward(retain_graph=True)
-
反向传播,计算当前梯度;loss2.backward()
-
根据梯度更新网络参数optimizer.step()
最后一个 backward() 不要加 retain_graph 参数,这样每次更新完成后会释放占用的内存,也就不会出现越来越慢的情况了
相关内容
这些是最新的
热门排行
- THINKPHP5+GatewayWorker+Workerman 开发在线客服系统
- 在手机浏览器网页中点击链接跳转到微信界面的方法
- 尊云网站目录系统 ThinkPHP5网站分类目录程序 v2.2.221011
- CentOS 7安装shadowsock(一键安装脚本)
- AdminTemplate 基于LayUI 2.4.5实现的网站后台管理模板
- 用NW.js(node-webkit)开发多平台的桌面客户端
- PHP生成随机昵称/用户名
- THINKPHP5网站分类目录程序 尊云网站目录系统
- 织梦(DEDECMS)微信支付接口 微信插件
- 基于LayUI开发的 网站后台管理模板 BeginnerAdmin
- 响应式后台网站模板 - AMA.ADMIN
- layuiAdmin后台管理模板 Iframe版
- LayUI 1.0.9 升级 至 LayUI 2.1.4 方法
- 简洁清爽的会员中心模板
- jQuery幸运大转盘抽奖活动代码