“Talk is cheap, show me the code”
In [1]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
from torchvision.datasets import MNIST
from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
In [2]:
# use gpu if we can
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
load data¶
In [3]:
BATCH_SIZE = 64
train_data = MNIST(
root='./data',
train=True,
download=True,
transform=transforms.ToTensor()
)
val_data = MNIST(
root='./data',
train=False,
transform=transforms.ToTensor()
)
train_loader = DataLoader(
dataset=train_data,
batch_size=BATCH_SIZE,
shuffle=True
)
val_loader = DataLoader(
dataset=val_data,
batch_size=BATCH_SIZE
)
In [4]:
for step, (batch_x, batch_y) in enumerate(train_loader):
if step > 0:
break
batch_x = batch_x.squeeze().numpy()
batch_y = batch_y.numpy()
class_label = train_data.classes
plt.figure(figsize=(12, 5))
for i in np.arange(len(batch_y)):
plt.subplot(4, 16, i+1)
plt.imshow(batch_x[i,:,:], cmap=plt.cm.gray)
plt.title(class_label[batch_y[i]], size=9)
plt.axis('off')
plt.subplots_adjust(wspace=0.05)
In [5]:
img_x, img_y = train_data[0]
img_x.shape
Out[5]:
torch.Size([1, 28, 28])
Model¶
In [6]:
NUM_CLASSES = 10
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 8, 3, 1, 1)
self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
self.conv3 = nn.Conv2d(16, 32, 3, 1, 1)
self.conv4 = nn.Conv2d(32, 64, 3, 1, 1)
self.view_features = 64* 7 * 7
self.fc1 = nn.Linear(self.view_features, 256)
self.fc2 = nn.Linear(256, 64)
self.fc3 = nn.Linear(64, NUM_CLASSES)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, self.view_features)
x = F.dropout(x, 0.5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return F.softmax(x, 1)
In [7]:
model = CNN().to(device)
summary(model, input_size=(1, 28, 28))
---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 8, 28, 28] 80 Conv2d-2 [-1, 16, 28, 28] 1,168 Conv2d-3 [-1, 32, 14, 14] 4,640 Conv2d-4 [-1, 64, 14, 14] 18,496 Linear-5 [-1, 256] 803,072 Linear-6 [-1, 64] 16,448 Linear-7 [-1, 10] 650 ================================================================ Total params: 844,554 Trainable params: 844,554 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.00 Forward/backward pass size (MB): 0.29 Params size (MB): 3.22 Estimated Total Size (MB): 3.51 ----------------------------------------------------------------
train¶
In [8]:
%%time
NUM_EPOCH = 50
PATIENCE = 4
patience_counter = 0
best_val_loss = float('inf')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-5)
history_train_loss = []
history_train_acc = []
history_val_loss = []
history_val_acc = []
for epoch in range(NUM_EPOCH):
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCH} [train]'):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * inputs.size(0)
_, predicted = torch.max(outputs.data, 1)
train_total += labels.size(0)
train_correct += (predicted == labels).sum().item()
train_loss = train_loss / len(train_loader.dataset)
train_acc = 100 * train_correct / train_total
history_train_loss.append(train_loss)
history_train_acc.append(train_acc)
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCH} [Val]'):
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item() * inputs.size(0)
_, predicted = torch.max(outputs.data, 1)
val_total += labels.size(0)
val_correct += (predicted == labels).sum().item()
val_loss = val_loss / len(val_loader.dataset)
val_acc = 100 * val_correct / val_total
history_val_loss.append(val_loss)
history_val_acc.append(val_acc)
print(
f'Epoch {epoch+1}/{NUM_EPOCH}: '
f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
f'Val Loss: {val_loss:.4f}, Train Acc: {val_acc:.2f}%, '
)
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
torch.save(model.state_dict(), "best_loss.pth")
else:
patience_counter += 1
print(f'Early stopping counter: {patience_counter}/{PATIENCE}')
if patience_counter >= PATIENCE:
print(f'Early stopping triggered at epoch {epoch+1}! Best loss: {best_val_loss}')
break
scheduler.step()
Epoch 1/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 287.81it/s] Epoch 1/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 462.33it/s]
Epoch 1/50: Train Loss: 1.7030, Train Acc: 75.92%, Val Loss: 1.5147, Train Acc: 94.72%, Early stopping counter: 0/4
Epoch 2/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 288.31it/s] Epoch 2/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 447.93it/s]
Epoch 2/50: Train Loss: 1.5039, Train Acc: 95.73%, Val Loss: 1.4892, Train Acc: 97.24%, Early stopping counter: 0/4
Epoch 3/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 282.15it/s] Epoch 3/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 462.55it/s]
Epoch 3/50: Train Loss: 1.4946, Train Acc: 96.66%, Val Loss: 1.4883, Train Acc: 97.32%, Early stopping counter: 0/4
Epoch 4/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 279.26it/s] Epoch 4/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 445.06it/s]
Epoch 4/50: Train Loss: 1.4897, Train Acc: 97.14%, Val Loss: 1.4872, Train Acc: 97.36%, Early stopping counter: 0/4
Epoch 5/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 281.26it/s] Epoch 5/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 457.82it/s]
Epoch 5/50: Train Loss: 1.4876, Train Acc: 97.34%, Val Loss: 1.4839, Train Acc: 97.71%, Early stopping counter: 0/4
Epoch 6/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 279.12it/s] Epoch 6/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 453.87it/s]
Epoch 6/50: Train Loss: 1.4846, Train Acc: 97.66%, Val Loss: 1.4830, Train Acc: 97.81%, Early stopping counter: 0/4
Epoch 7/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 275.70it/s] Epoch 7/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 397.92it/s]
Epoch 7/50: Train Loss: 1.4849, Train Acc: 97.62%, Val Loss: 1.4825, Train Acc: 97.87%, Early stopping counter: 0/4
Epoch 8/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 281.42it/s] Epoch 8/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 465.34it/s]
Epoch 8/50: Train Loss: 1.4821, Train Acc: 97.90%, Val Loss: 1.4841, Train Acc: 97.69%, Early stopping counter: 1/4
Epoch 9/50 [train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 291.98it/s] Epoch 9/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 460.54it/s]
Epoch 9/50: Train Loss: 1.4806, Train Acc: 98.05%, Val Loss: 1.4839, Train Acc: 97.69%, Early stopping counter: 2/4
Epoch 10/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 294.92it/s] Epoch 10/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 446.51it/s]
Epoch 10/50: Train Loss: 1.4798, Train Acc: 98.13%, Val Loss: 1.4800, Train Acc: 98.10%, Early stopping counter: 0/4
Epoch 11/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 284.85it/s] Epoch 11/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 418.88it/s]
Epoch 11/50: Train Loss: 1.4777, Train Acc: 98.34%, Val Loss: 1.4772, Train Acc: 98.35%, Early stopping counter: 0/4
Epoch 12/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 284.02it/s] Epoch 12/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 455.68it/s]
Epoch 12/50: Train Loss: 1.4765, Train Acc: 98.45%, Val Loss: 1.4747, Train Acc: 98.63%, Early stopping counter: 0/4
Epoch 13/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 286.78it/s] Epoch 13/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 463.43it/s]
Epoch 13/50: Train Loss: 1.4756, Train Acc: 98.56%, Val Loss: 1.4752, Train Acc: 98.58%, Early stopping counter: 1/4
Epoch 14/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 294.38it/s] Epoch 14/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 457.06it/s]
Epoch 14/50: Train Loss: 1.4748, Train Acc: 98.62%, Val Loss: 1.4749, Train Acc: 98.59%, Early stopping counter: 2/4
Epoch 15/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 292.61it/s] Epoch 15/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 458.95it/s]
Epoch 15/50: Train Loss: 1.4730, Train Acc: 98.81%, Val Loss: 1.4725, Train Acc: 98.83%, Early stopping counter: 0/4
Epoch 16/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 292.37it/s] Epoch 16/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 463.87it/s]
Epoch 16/50: Train Loss: 1.4722, Train Acc: 98.90%, Val Loss: 1.4725, Train Acc: 98.87%, Early stopping counter: 1/4
Epoch 17/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 285.26it/s] Epoch 17/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 447.51it/s]
Epoch 17/50: Train Loss: 1.4721, Train Acc: 98.91%, Val Loss: 1.4730, Train Acc: 98.81%, Early stopping counter: 2/4
Epoch 18/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 288.98it/s] Epoch 18/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 466.31it/s]
Epoch 18/50: Train Loss: 1.4714, Train Acc: 98.97%, Val Loss: 1.4736, Train Acc: 98.76%, Early stopping counter: 3/4
Epoch 19/50 [train]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:03<00:00, 293.37it/s] Epoch 19/50 [Val]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 465.71it/s]
Epoch 19/50: Train Loss: 1.4709, Train Acc: 99.02%, Val Loss: 1.4733, Train Acc: 98.82%, Early stopping counter: 4/4 Early stopping triggered at epoch 19! Best loss: 1.4725148626327516 CPU times: user 1min 13s, sys: 632 ms, total: 1min 13s Wall time: 1min 8s
In [9]:
epochs = range(1, len(history_train_loss) + 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs, history_train_loss, 'b', label='Trainning Loss')
plt.plot(epochs, history_val_loss, 'r', label='Val loss')
plt.title('Loss Curve')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.subplot(1, 2, 2)
plt.plot(epochs, history_train_acc, 'b', label='Trainning ACC')
plt.plot(epochs, history_val_acc, 'r', label='Val ACC')
plt.title('ACC Curve')
plt.xlabel('Epochs')
plt.ylabel('ACC')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
Comments NOTHING