Vision Transformers (ViT) — Explained + Sample Code

Treating images as sequences: how ViT works, pros/cons, and practical PyTorch snippets you can run.

TL;DR: Vision Transformers (ViT) divide an image into patches, embed each patch, and feed the resulting sequence to a Transformer encoder. They perform best after large-scale pretraining. Because self-attention lets every patch attend to every other patch, they model global context well.

How ViT works — a quick overview

  1. Patchify: split the image into fixed-size, non-overlapping patches (e.g., 16×16) and flatten each patch into a vector.
  2. Linear projection: map each flattened patch to a D-dimensional embedding.
  3. Classification token: prepend a learnable cls token to the patch sequence. Its final embedding is used for classification.
  4. Positional encoding: add learnable position embeddings (one per token, including the cls token) so the model knows where each patch came from.
  5. Transformer encoder: process the sequence with stacked blocks of multi-head self-attention and MLP layers.
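In practice, the patchify and linear-projection steps are often fused into a single strided convolution. Both timm's and torchvision's ViT implementations do this. The sketch below is illustrative. It uses random inputs and the example sizes P=16 and D=128, and it checks that a Conv2d with kernel_size = stride = patch size computes the same embeddings as explicit patchify followed by nn.Linear, once the two layers share weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 3, 224, 224
P, D = 16, 128  # patch size and embedding dim (illustrative values)

x = torch.randn(B, C, H, W)

# Patchify + projection as a strided convolution: kernel_size = stride = P,
# so each output position sees exactly one non-overlapping P x P patch.
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
emb_conv = conv(x).flatten(2).transpose(1, 2)  # (B, N, D) with N = (H/P) * (W/P)

# The same computation written as explicit patchify + nn.Linear.
patches = x.unfold(2, P, P).unfold(3, P, P)                            # (B, C, H/P, W/P, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)  # (B, N, C*P*P)
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    # Copy the conv kernel (D, C, P, P) into the linear weight (D, C*P*P).
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)
emb_linear = linear(patches)

print(emb_conv.shape)  # torch.Size([2, 196, 128])
print(torch.allclose(emb_conv, emb_linear, atol=1e-4))  # True
```

The convolution form avoids materializing the flattened patch tensor, which is why library implementations tend to prefer it.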

Why ViT?

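Strengths
  - Models global relationships well: every patch can attend to every other patch, starting from the first layer.
  - Scales well with more data and larger models.
  - Uses a unified Transformer architecture shared with NLP.

Limitations
  - Needs large pretraining datasets to perform well.
  - Self-attention cost grows quadratically with the number of patches, so small patch sizes or high-resolution inputs are compute-heavy.
  - Has weaker built-in inductive biases (locality, translation equivariance) than CNNs, so it is less sample-efficient when trained from scratch on small datasets.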
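The quadratic-cost limitation is easy to quantify. The sketch below uses a hypothetical helper, vit_sequence_stats, to count the tokens in the sequence (patches plus the cls token) and the number of attention scores per head per layer for a 224×224 input at several patch sizes.

```python
def vit_sequence_stats(img_size: int, patch_size: int) -> tuple[int, int]:
    """Hypothetical helper: (tokens incl. cls token, attention scores per head per layer)."""
    n_patches = (img_size // patch_size) ** 2
    seq_len = n_patches + 1  # +1 for the cls token
    # Self-attention scores every token against every token: seq_len x seq_len.
    return seq_len, seq_len * seq_len

for p in (32, 16, 8):
    seq_len, scores = vit_sequence_stats(224, p)
    print(f"patch {p:>2}: {seq_len:>4} tokens, {scores:>7} attention scores per head per layer")
```

At 224×224, going from patch size 16 (197 tokens) to patch size 8 (785 tokens) makes the attention matrix roughly 16× larger. The per-token MLP cost, by contrast, grows only linearly with the number of tokens.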

Minimal PyTorch example — Patchify + Transformer Encoder

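This example shows the core idea: split an image into patches, embed them, and run a small Transformer encoder over the sequence. It is meant for learning and is not optimized for production. For example, it patchifies images one at a time in a Python loop.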

# Requirements: torch, torchvision
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

# 1) Patchify helper
def image_to_patches(img_tensor, patch_size):
    # img_tensor: (C, H, W)
    C, H, W = img_tensor.shape
    assert H % patch_size == 0 and W % patch_size == 0
    patches = img_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # patches shape: (C, num_h, num_w, patch_size, patch_size)
    patches = patches.permute(1,2,0,3,4).contiguous()
    num_patches = patches.shape[0] * patches.shape[1]
    patches = patches.view(num_patches, -1)  # (num_patches, C*patch_size*patch_size)
    return patches

# 2) Simple ViT-like module
class SimpleViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, emb_dim=128, num_heads=4, depth=6, num_classes=1000):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.patch_dim = 3 * patch_size * patch_size
        self.to_patch_emb = nn.Linear(self.patch_dim, emb_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.pos_emb = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=emb_dim*4)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.mlp_head = nn.Sequential(nn.LayerNorm(emb_dim), nn.Linear(emb_dim, num_classes))

    def forward(self, x):
        # x: (B, C, H, W)
        B = x.shape[0]
        patches = []
        for i in range(B):
            p = image_to_patches(x[i], self.patch_size)  # (num_patches, patch_dim)
            patches.append(p)
        patches = torch.stack(patches)  # (B, num_patches, patch_dim)
        patches = self.to_patch_emb(patches)  # (B, num_patches, emb_dim)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B,1,emb_dim)
        x = torch.cat([cls_tokens, patches], dim=1)  # (B, num_patches+1, emb_dim)
        x = x + self.pos_emb

        # nn.TransformerEncoderLayer defaults to batch_first=False, so the encoder expects (S, B, E).
        # Permute to sequence-first here and back afterwards (passing batch_first=True would avoid this).
        x = x.permute(1,0,2)  # (S, B, E)
        x = self.transformer(x)
        x = x.permute(1,0,2)  # (B, S, E)

        cls_final = x[:,0]  # (B, emb_dim)
        return self.mlp_head(cls_final)

# 3) Quick run (dummy all-white image; the untrained model's logits are meaningless)
if __name__ == '__main__':
    img = Image.new('RGB', (224,224), color='white')
    tf = T.Compose([T.ToTensor()])
    t = tf(img).unsqueeze(0)  # (1,3,224,224)
    model = SimpleViT(img_size=224, patch_size=16, emb_dim=128, num_heads=4, depth=4, num_classes=10)
    logits = model(t)
    print('logits', logits.shape)
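The per-image loop in SimpleViT.forward is easy to read but slow for large batches. The sketch below shows a batched variant, using a hypothetical helper named batch_to_patches that applies the same unfold/permute logic to the whole (B, C, H, W) tensor at once. To stay self-contained, it repeats the per-image helper and then checks that both versions produce identical patches.

```python
import torch

def image_to_patches(img_tensor: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Per-image version (same logic as the helper above): (C, H, W) -> (N, C*P*P)
    C, H, W = img_tensor.shape
    p = img_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    return p.permute(1, 2, 0, 3, 4).reshape(-1, C * patch_size * patch_size)

def batch_to_patches(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Hypothetical batched version: (B, C, H, W) -> (B, N, C*P*P), no Python loop.
    B, C, H, W = x.shape
    assert H % patch_size == 0 and W % patch_size == 0
    p = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)  # (B, C, H/P, W/P, P, P)
    return p.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * patch_size * patch_size)

x = torch.randn(4, 3, 224, 224)
looped = torch.stack([image_to_patches(img, 16) for img in x])
batched = batch_to_patches(x, 16)
print(batched.shape)  # torch.Size([4, 196, 768])
print(torch.equal(looped, batched))  # True
```

Inside SimpleViT.forward, the loop and torch.stack could then be replaced with `patches = batch_to_patches(x, self.patch_size)`. Because both versions only rearrange elements, the results match exactly rather than approximately.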

Using a pretrained ViT (recommended for real tasks)

Use libraries like timm or torchvision's pretrained models to leverage large-scale pretraining.

# Requirements: timm, torch, pillow
# pip install timm
import timm
import torch
from PIL import Image
from timm.data import resolve_data_config, create_transform

model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# Preprocess using the model's own data config (input size, resize/crop ratio, mean/std).
# ViT checkpoints in timm commonly normalize with mean = std = 0.5 rather than the
# ImageNet statistics, so hard-coding ImageNet mean/std can silently hurt accuracy.
config = resolve_data_config({}, model=model)
preprocess = create_transform(**config)

img = Image.open('img/example.jpg').convert('RGB')
input_tensor = preprocess(img).unsqueeze(0)  # (1,3,224,224)

with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top5 = torch.topk(probs, k=5)
    print('Top5 indices:', top5.indices)
    print('Top5 probabilities:', top5.values)