使用pytorch的时候,想要对网络进行可视化,结果报错为:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
这个是使用GPU 的问题,两种修改方法:
|
1 2 3 |
if __name__ == '__main__': model = fishnet99() torchsummary.summary(model.cuda(), (3, 224, 224)) |
|
1 2 3 |
if __name__ == '__main__': model = fishnet99() torchsummary.summary(model, (3, 224, 224),device='cpu') |
第一种就是 convert your network to cuda,第二种就是 call torchsummary.summary with device='cpu'




最新评论
同样
站长您好,亚马逊云咨询推广资源,望建立联系,可邮件,谢谢。
换友情链接吗?
看你的站做的挺不错的
恭喜!!太强了,硕博连读啊
雁过留毛,人过留名。
看不懂但大受震撼
每天都在战争,希望2026和平.