参考链接:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click&vd_source=e01172ea292c1c605b346101d7006c61
# 一、直接搭建
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class SelfNet(nn.Module):
    """Small CNN built layer by layer, each layer stored as its own attribute.

    The network is three (5x5 convolution -> 2x2 max-pool) stages followed by
    a flatten and two linear layers. No activation functions are applied
    between the layers.

    For an input of shape (N, 3, 32, 32) each convolution keeps the spatial
    size (kernel 5, padding 2, default stride 1) and each pool halves it:
    32 -> 16 -> 8 -> 4. The flattened feature size is therefore
    64 * 4 * 4 = 1024, which is what ``linear1`` expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()
        self.linear1 = Linear(1024, 64)
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        """Run ``x`` through every layer in order and return the result.

        Args:
            x: Input tensor. ``linear1`` requires 1024 flattened features per
                sample, which a (N, 3, 32, 32) input produces.

        Returns:
            Tensor of shape (N, 10).
        """
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


# Smoke test: print the architecture, then push a dummy batch through the
# network to confirm the layer sizes line up.
selfNet = SelfNet()
print(selfNet)
# Named ``dummy_input`` rather than ``input`` so the builtin is not shadowed.
dummy_input = torch.ones((64, 3, 32, 32))
output = selfNet(dummy_input)
print(output.shape)  # torch.Size([64, 10])
# 二、使用 nn.Sequential 搭建
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class SelfNet(nn.Module):
    """Small CNN whose layers are grouped in a single ``nn.Sequential``.

    The pipeline is three (5x5 convolution -> 2x2 max-pool) stages followed by
    a flatten and two linear layers. No activation functions are applied
    between the layers.

    For an input of shape (N, 3, 32, 32) each convolution keeps the spatial
    size (kernel 5, padding 2, default stride 1) and each pool halves it:
    32 -> 16 -> 8 -> 4. The flattened feature size is therefore
    64 * 4 * 4 = 1024, which is what the first ``Linear`` expects.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Pass ``x`` through the sequential pipeline and return the result.

        Args:
            x: Input tensor. The first ``Linear`` requires 1024 flattened
                features per sample, which a (N, 3, 32, 32) input produces.

        Returns:
            Tensor of shape (N, 10).
        """
        return self.model(x)


# Smoke test: print the architecture, then push a dummy batch through the
# network to confirm the layer sizes line up.
selfNet = SelfNet()
print(selfNet)
# Named ``dummy_input`` rather than ``input`` so the builtin is not shadowed.
dummy_input = torch.ones((64, 3, 32, 32))
output = selfNet(dummy_input)
print(output.shape)  # torch.Size([64, 10])