TL;DR: Vision Transformers (ViT) divide an image into patches, embed them, and feed the sequence to a Transformer encoder. They excel with large-scale pretraining and provide strong global context modeling for vision tasks.
How ViT works — a quick overview
- Patchify: split image into fixed-size patches (e.g., 16×16) and flatten each patch.
- Linear projection: map each flattened patch to a D-dimensional embedding.
- Positional encoding: add position embeddings so the model knows patch positions.
- Transformer encoder: process the sequence of patch embeddings with multi-head self-attention + MLP blocks.
- Classification token: prepend a learnable CLS token whose final embedding is used for classification (the shape walkthrough below traces these steps numerically).
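To make the steps above concrete, here is a quick, illustrative shape walkthrough using the standard ViT-Base configuration (224×224 RGB input, 16×16 patches, 768-dimensional embeddings):

# Illustrative shape walkthrough for a 224x224 RGB image with 16x16 patches (ViT-Base numbers)
img_size, patch_size, channels, embed_dim = 224, 16, 3, 768
num_patches = (img_size // patch_size) ** 2     # 14 * 14 = 196 patches
patch_dim = channels * patch_size * patch_size  # 3 * 16 * 16 = 768 values per flattened patch
seq_len = num_patches + 1                       # +1 for the CLS token -> 197 tokens
print(num_patches, patch_dim, seq_len)          # 196 768 197
# After the linear projection, the encoder input has shape (batch, 197, embed_dim)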
Why ViT?
| Strengths | Limitations |
| --- | --- |
| Excellent at modeling global relationships; scalable with data; unified architecture across vision and NLP. | Requires large pretraining datasets; attention scales quadratically with the number of patches (compute heavy); less sample-efficient on small datasets. |
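To put the quadratic attention cost from the table in numbers, here is a small illustrative comparison of token counts and attention-matrix entries per head at two common input resolutions:

# Illustrative only: attention scores per head grow with the square of the token count
patch_size = 16
for img_size in (224, 384):
    tokens = (img_size // patch_size) ** 2 + 1  # patches + CLS token
    print(img_size, tokens, tokens * tokens)    # 224 -> 197 tokens (~39k scores); 384 -> 577 tokens (~333k scores)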
Minimal PyTorch example — Patchify + Transformer Encoder
This example shows the core idea: converting an image into patches and running a small Transformer encoder. It is educational — not optimized for production.
# Requirements: torch, torchvision
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
# 1) Patchify helper
def image_to_patches(img_tensor, patch_size):
    # img_tensor: (C, H, W)
    C, H, W = img_tensor.shape
    assert H % patch_size == 0 and W % patch_size == 0
    patches = img_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # patches shape: (C, num_h, num_w, patch_size, patch_size)
    patches = patches.permute(1, 2, 0, 3, 4).contiguous()
    num_patches = patches.shape[0] * patches.shape[1]
    patches = patches.view(num_patches, -1)  # (num_patches, C*patch_size*patch_size)
    return patches
# 2) Simple ViT-like module
class SimpleViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, emb_dim=128, num_heads=4, depth=6, num_classes=1000):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.patch_dim = 3 * patch_size * patch_size
        self.to_patch_emb = nn.Linear(self.patch_dim, emb_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.pos_emb = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=emb_dim * 4)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.mlp_head = nn.Sequential(nn.LayerNorm(emb_dim), nn.Linear(emb_dim, num_classes))

    def forward(self, x):
        # x: (B, C, H, W)
        B = x.shape[0]
        patches = []
        for i in range(B):
            p = image_to_patches(x[i], self.patch_size)  # (num_patches, patch_dim)
            patches.append(p)
        patches = torch.stack(patches)  # (B, num_patches, patch_dim)
        patches = self.to_patch_emb(patches)  # (B, num_patches, emb_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, emb_dim)
        x = torch.cat([cls_tokens, patches], dim=1)  # (B, num_patches+1, emb_dim)
        x = x + self.pos_emb
        # nn.TransformerEncoder defaults to batch_first=False, so it expects (S, B, E)
        x = x.permute(1, 0, 2)  # (S, B, E)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # (B, S, E)
        cls_final = x[:, 0]  # (B, emb_dim)
        return self.mlp_head(cls_final)
# 3) Quick run (random image)
if __name__ == '__main__':
    img = Image.new('RGB', (224, 224), color='white')
    tf = T.Compose([T.ToTensor()])
    t = tf(img).unsqueeze(0)  # (1, 3, 224, 224)
    model = SimpleViT(img_size=224, patch_size=16, emb_dim=128, num_heads=4, depth=4, num_classes=10)
    logits = model(t)
    print('logits', logits.shape)
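The per-image loop in forward() keeps the patchify step explicit but is slow. Most real ViT implementations instead use a single strided convolution whose kernel size and stride equal the patch size, which is mathematically equivalent to flattening each patch and applying a shared linear projection. A minimal, illustrative sketch (not a drop-in replacement for SimpleViT above):

# Batched patch embedding via a strided convolution (flatten + linear projection in one op)
import torch
import torch.nn as nn

patch_size, emb_dim = 16, 128
patch_embed = nn.Conv2d(3, emb_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(8, 3, 224, 224)             # (B, C, H, W)
tokens = patch_embed(x)                     # (B, emb_dim, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)  # (B, 196, emb_dim), ready for the encoder
print(tokens.shape)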
Using a pretrained ViT (recommended for real tasks)
Use libraries such as timm or torchvision's pretrained models to leverage large-scale pretraining.
# Requirements: timm, torch, torchvision
# pip install timm
import timm
import torch
from torchvision import transforms
from PIL import Image
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()
# Preprocess
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
img = Image.open('img/example.jpg').convert('RGB')
input_tensor = preprocess(img).unsqueeze(0)  # (1, 3, 224, 224)
with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.nn.functional.softmax(logits, dim=-1)
top5 = torch.topk(probs, k=5)
print('Top5 indices:', top5.indices)
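The hard-coded ImageNet statistics above are a reasonable default, but timm can also derive the exact preprocessing a given checkpoint was trained with. A minimal sketch, assuming a recent timm version:

# Derive the model's own preprocessing (input size, mean/std, interpolation) from timm
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

model = timm.create_model('vit_base_patch16_224', pretrained=True)  # same model as above
config = resolve_data_config({}, model=model)
preprocess = create_transform(**config)  # torchvision-style Compose matching the checkpoint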
Notes:
- ViT models are typically pretrained on large datasets for best results; for smaller datasets, start from pretrained weights and fine-tune (a minimal sketch follows these notes).
- Consider efficient variants (Swin Transformer, DeiT) for resource-constrained environments.
- The minimal example above illustrates the idea — production implementations rely on optimized libraries.
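For the transfer-learning note above, here is a minimal fine-tuning sketch with timm. The 10-class head and the frozen backbone are illustrative choices; the training loop and dataset are left to your setup:

# Minimal transfer-learning sketch: pretrained ViT backbone, new 10-class head
import timm
import torch

model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)  # fresh head
for p in model.parameters():
    p.requires_grad = False  # freeze everything...
for p in model.get_classifier().parameters():
    p.requires_grad = True   # ...then unfreeze only the classification head

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
# ...standard training loop over your (small) dataset goes here...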
Resources & further reading
- Original ViT paper: Dosovitskiy et al., 2020 — “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”
- DeiT (Data-efficient Image Transformers)
- Swin Transformer — hierarchical and efficient ViT variant
- timm library — many pretrained vision models