Escrevendo ResNet do zero em PyTorch
Neste artigo, construiremos o ResNet, um grande avanço na visão computacional, que resolveu o problema de degradação do desempenho da rede se a rede for muito profunda. Também introduziu o conceito de Conexões Residuais (mais sobre isso mais tarde).
Começaremos examinando a arquitetura e a intuição por trás de como o ResNet funciona. Em seguida, iremos compará-lo com o VGG e examinar como ele resolve alguns dos problemas que o VGG teve. Então, como antes, carregaremos nosso conjunto de dados, CIFAR10, e o pré-processaremos para deixá-lo pronto para modelagem. Então, primeiro implementaremos o bloco de construção básico de um ResNet (chamaremos isso de ResidualBlock) e usaremos isso para construir nossa rede. Em seguida, esta rede será treinada nos dados pré-processados e, por fim, veremos o desempenho do modelo treinado em dados não vistos (conjunto de testes).
Pré-requisitos
Para acompanhar este artigo, você precisará de experiência básica com código Python e um conhecimento básico de Deep Learning. Operaremos sob a suposição de que todos os leitores tenham acesso a máquinas suficientemente poderosas para que possam executar o código fornecido. GPUs menos potentes também podem ser usadas, mas os resultados podem demorar mais para serem alcançados.
Caso você não tenha acesso a uma GPU, sugerimos acessá-la pela nuvem. Existem muitos provedores de nuvem que oferecem GPUs. Os Droplets de GPU DigitalOcean estão atualmente em Disponibilidade Antecipada, saiba mais e inscreva-se para obter interesse em Droplets de GPU aqui
Para obter instruções sobre como começar a usar o código Python, recomendamos experimentar este guia para iniciantes para configurar seu sistema e se preparar para executar tutoriais para iniciantes.
ResNet
Uma das desvantagens do VGG era que ele não conseguia ir tão fundo quanto desejado porque começou a perder a capacidade de generalização (ou seja, começou a se ajustar demais). Isso ocorre porque à medida que uma rede neural se aprofunda, os gradientes da função de perda começam a diminuir para zero e, portanto, os pesos não são atualizados. Este problema é conhecido como problema do gradiente evanescente. ResNet essencialmente resolveu esse problema usando conexões saltadas.
Um bloco residual. Fonte: Artigo ResNet
Na figura acima podemos observar que, além das conexões normais, existe uma conexão direta que pula algumas camadas do modelo (skip connection). Com a conexão de salto, a saída muda de h(x)=f(wx +b) para h(x)=f(x) + x. Essas conexões de salto ajudam, pois permitem um caminho de atalho alternativo para o fluxo dos gradientes. Abaixo está a arquitetura do ResNet de 34 camadas.
Fonte: Artigo ResNet
Carregamento de dados: o conjunto de dados
Neste artigo, usaremos o famoso conjunto de dados CIFAR-10, que se tornou uma das escolhas mais comuns para conjuntos de dados de visão computacional para iniciantes. O conjunto de dados é um subconjunto rotulado do conjunto de dados de 80 milhões de imagens minúsculas. Eles foram coletados por Alex Krizhevsky, Vinod Nair e Geoffrey Hinton. O conjunto de dados CIFAR-10 consiste em 60.000 imagens coloridas 32x32 em 10 classes, com 6.000 imagens por classe. Existem 50.000 imagens de treinamento e 10.000 imagens de teste.
O conjunto de dados é dividido em cinco lotes de treinamento e um lote de teste, cada um com 10.000 imagens. O lote de teste contém exatamente 1.000 imagens selecionadas aleatoriamente de cada classe. Os lotes de treinamento contêm as imagens restantes em ordem aleatória, mas alguns lotes de treinamento podem conter mais imagens de uma classe do que de outra. Entre eles, os lotes de treinamento contêm exatamente 5.000 imagens de cada turma. As aulas são completamente mutuamente exclusivas. Não há sobreposição entre automóveis e caminhões. “Automóvel” inclui sedãs, SUVs e coisas desse tipo. “Caminhão” inclui apenas caminhões grandes. Nenhum dos dois inclui picapes.
Aqui estão as classes no conjunto de dados, bem como 10 imagens aleatórias de cada:
Fonte: papel
Importando as Bibliotecas
Começaremos importando as bibliotecas que usaríamos. Além disso, garantiremos que o Notebook use a GPU para treinar o modelo, caso esteja disponível
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Importando as bibliotecas
Carregando o conjunto de dados
Agora passamos a carregar nosso conjunto de dados. Para isso, usaremos a biblioteca torchvision
que não só fornece acesso rápido a centenas de conjuntos de dados de visão computacional, mas também métodos fáceis e intuitivos para pré-processá-los/transformá-los para que estejam prontos para modelagem.
Começamos definindo nossa função
data_loader
que retorna os dados de treinamento ou teste dependendo dos argumentosÉ sempre uma boa prática normalizar nossos dados em projetos de Deep Learning, pois isso torna o treinamento mais rápido e fácil de convergir. Para isso, definimos a variável
normalize
com a média e os desvios padrão de cada canal (vermelho, verde e azul) do conjunto de dados. Estes podem ser calculados manualmente, mas também estão disponíveis online. Isso é usado na variáveltransform
onde redimensionamos os dados, os convertemos em tensores e depois os normalizamosFazemos uso de carregadores de dados. Os carregadores de dados nos permitem iterar os dados em lotes, e os dados são carregados durante a iteração e não todos de uma vez no início em nossa RAM. Isso é muito útil se estivermos lidando com grandes conjuntos de dados de cerca de milhões de imagens.
Dependendo do argumento
test
, carregamos a divisão do trem (iftest=False
) ou otest
( iftest=True
) divisão. No caso de trem, a divisão é dividida aleatoriamente em trem e conjunto de validação (0,9:0,1).def data_loader(data_dir, batch_size, random_seed=42, valid_size=0.1, shuffle=True, test=False): normalize = transforms.Normalize( mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010], ) # define transforms transform = transforms.Compose([ transforms.Resize((224,224)), transforms.ToTensor(), normalize, ]) if test: dataset = datasets.CIFAR10( root=data_dir, train=False, download=True, transform=transform, ) data_loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=shuffle ) return data_loader # load the dataset train_dataset = datasets.CIFAR10( root=data_dir, train=True, download=True, transform=transform, ) valid_dataset = datasets.CIFAR10( root=data_dir, train=True, download=True, transform=transform, ) num_train = len(train_dataset) indices = list(range(num_train)) split = int(np.floor(valid_size * num_train)) if shuffle: np.random.seed(42) np.random.shuffle(indices) train_idx, valid_idx = indices[split:], indices[:split] train_sampler = SubsetRandomSampler(train_idx) valid_sampler = SubsetRandomSampler(valid_idx) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, sampler=train_sampler) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=batch_size, sampler=valid_sampler) return (train_loader, valid_loader) train_loader, valid_loader = data_loader(data_dir='./data', batch_size=64) test_loader = data_loader(data_dir='./data', batch_size=64, test=True)
ResNet from Scratch: como os modelos funcionam no PyTorch
Antes de prosseguirmos para a construção do bloco residual e do ResNet, primeiro examinaríamos e entenderíamos como as redes neurais são definidas no PyTorch:
nn.Module
fornece um padrão para a criação de modelos personalizados junto com algumas funcionalidades necessárias que ajudam no treinamento. É por isso que todo modelo personalizado tende a herdar denn.Module
- Depois, há duas funções principais dentro de cada modelo personalizado. A primeira é a função de inicialização,
__init__
, onde definimos as várias camadas que usaremos, e a segunda é a funçãoforward
, que define a sequência na qual as camadas acima serão ser executado em uma determinada entrada
Camadas no PyTorch
Chegando agora aos diferentes tipos de camadas disponíveis no PyTorch que são úteis para nós:
nn.Conv2d
: Estas são as camadas convolucionais que aceitam o número de canais de entrada e saída como argumentos, junto com o tamanho do kernel para o filtro. Ele também aceita quaisquer avanços ou preenchimentos se quisermos aplicá-los.nn.BatchNorm2d
: aplica a normalização em lote à saída da camada convolucionalnn.ReLU
: este é um tipo de função de ativação aplicada a várias saídas na redenn.MaxPool2d
: aplica o pool máximo à saída com o tamanho do kernel fornecidonn.Dropout
: Isso é usado para aplicar dropout à saída com uma determinada probabilidadenn.Linear
: Esta é basicamente uma camada totalmente conectadann.Sequential
: tecnicamente não é um tipo de camada, mas ajuda a combinar diferentes operações que fazem parte da mesma etapa
Bloco Residual
Antes de começar com a rede, precisamos construir um ResidualBlock que possamos reutilizar em toda a rede. O bloco (conforme mostrado na arquitetura) contém uma conexão de salto que é um parâmetro opcional ( downsample
). Observe que no forward
, isso é aplicado diretamente à entrada, x
, e não à saída, out
.
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
nn.BatchNorm2d(out_channels),
nn.ReLU())
self.conv2 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels))
self.downsample = downsample
self.relu = nn.ReLU()
self.out_channels = out_channels
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.conv2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
Bloco Residual
ResNet34
Agora que criamos o ResidualBlock, podemos construir nosso ResNet.
Observe que existem três blocos na arquitetura, contendo 3, 3, 6 e 3 camadas respectivamente. Para fazer este bloco, criamos uma função auxiliar _make_layer
. A função adiciona as camadas uma por uma junto com o Bloco Residual. Após os blocos, adicionamos o pooling médio e a camada linear final.
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes = 10):
super(ResNet, self).__init__()
self.inplanes = 64
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
nn.BatchNorm2d(64),
nn.ReLU())
self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
self.layer3 = self._make_layer(block, 512, layers[3], stride = 2)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512, num_classes)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
nn.BatchNorm2d(planes),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.layer0(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
Configurando hiperparâmetros
É sempre recomendável experimentar valores diferentes para vários hiperparâmetros em nosso modelo, mas aqui usaremos apenas uma configuração. Independentemente disso, recomendamos que todos experimentem diferentes e vejam qual funciona melhor. Os hiperparâmetros incluem a definição do número de épocas, tamanho do lote, taxa de aprendizado, função de perda junto com o otimizador. Como estamos construindo a variante de 34 camadas do ResNet, precisamos passar também o número apropriado de camadas:
num_classes = 10
num_epochs = 20
batch_size = 16
learning_rate = 0.01
model = ResNet(ResidualBlock, [3, 4, 6, 3]).to(device)
#Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.001, momentum = 0.9)
#Train the model
total_step = len(train_loader)
Treinamento
Agora, nosso modelo está pronto para treinamento, mas primeiro precisamos saber como funciona o treinamento do modelo no PyTorch:
Começamos carregando as imagens em lotes usando nosso
train_loader
para cada época, e também movemos os dados para a GPU usando a variáveldevice
que definimos anteriormenteO modelo é então usado para prever nos rótulos,
model(images)
, e então calculamos a perda entre as previsões e a verdade básica usando a função de perda definida acima,criterion(outputs, rótulos)
Agora vem a parte do aprendizado, usamos o método de perda para retropropagar,
loss.backward()
, e atualizamos os pesos,optimizer.step()
. Uma coisa importante que é necessária antes de cada atualização é definir os gradientes como zero usandooptimizer.zero_grad()
porque caso contrário, os gradientes serão acumulados (comportamento padrão no PyTorch)Por fim, após cada época, testamos nosso modelo no conjunto de validação, mas, como não precisamos de gradientes na avaliação, podemos desligá-lo usando
with torch.no_grad()
para fazer a avaliação muito mais rápido.import gc total_step = len(train_loader) for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_loader): #Move tensors to the configured device images = images.to(device) labels = labels.to(device) #Forward pass outputs = model(images) loss = criterion(outputs, labels) #Backward and optimize optimizer.zero_grad() loss.backward() optimizer.step() del images, labels, outputs torch.cuda.empty_cache() gc.collect() print ('Epoch [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, loss.item())) #Validation with torch.no_grad(): correct = 0 total = 0 for images, labels in valid_loader: images = images.to(device) labels = labels.to(device) outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() del images, labels, outputs print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))
Analisando a saída do código, podemos ver que o modelo está aprendendo à medida que a perda diminui enquanto a precisão do conjunto de validação aumenta a cada época. Mas podemos notar que ele está flutuando no final, o que pode significar que o modelo está superajustado ou que o batch_size
é pequeno. Teremos que testar para descobrir o que está acontecendo:
Teste
Para teste, usamos exatamente o mesmo código da validação, mas com o test_loader
:
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
del images, labels, outputs
print('Accuracy of the network on the {} test images: {} %'.format(10000, 100 * correct / total))
Usando o código acima e treinando o modelo por 10 épocas, conseguimos uma precisão de 82,87% no conjunto de teste:
Precisão de teste
Conclusão
Vamos agora concluir o que fizemos neste artigo:
- Começamos entendendo a arquitetura e como funciona o ResNet
- Em seguida, carregamos e pré-processamos o conjunto de dados CIFAR10 usando
torchvision
- Em seguida, aprendemos como as definições de modelos personalizados funcionam no PyTorch e os diferentes tipos de camadas disponíveis no
torch
- Construímos nosso ResNet do zero, construindo um ResidualBlock
- Por fim, treinamos e testamos nosso modelo no conjunto de dados CIFAR10, e o modelo pareceu funcionar bem no conjunto de dados de teste com 75% de precisão
Trabalho Futuro
Usando este artigo, obtivemos uma boa introdução e um aprendizado prático, mas podemos aprender muito mais se estendermos isso a outros desafios:
- Tente usar conjuntos de dados diferentes. Um desses conjuntos de dados é o CIFAR100, um subconjunto do conjunto de dados ImageNet, ou o conjunto de dados de 80 milhões de imagens minúsculas
- Experimente diferentes hiperparâmetros e veja a melhor combinação deles para o modelo
- Por fim, tente adicionar ou remover camadas do conjunto de dados para ver seu impacto na capacidade do modelo. Melhor ainda, tente construir a versão ResNet-51 deste modelo