TL;DR: Vision Transformers (ViT) divide an image into patches, embed them, and feed the sequence to a Transformer encoder. They excel with large-scale pretraining and provide strong global context modeling for vision tasks.
How ViT works — a quick overview
- Patchify: split the image into fixed-size patches (e.g., 16×16) and flatten each patch (the shape arithmetic is sketched just after this list).
- Linear projection: map each flattened patch to a D-dimensional embedding.
- Positional encoding: add position embeddings so the model knows where each patch came from.
- Transformer encoder: process the sequence of patch embeddings with multi-head self-attention + MLP blocks.
- Classification token: prepend a learnable [CLS] token whose final embedding is used for classification.
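To make the bookkeeping concrete, here is a quick shape sketch; the specific numbers (224×224 input, 16×16 patches, 768-dim embeddings) are just the common ViT-Base-style configuration, not something the recipe above requires.

# Token and shape arithmetic for an illustrative ViT-Base-style setup
img_size, patch_size, channels, emb_dim = 224, 16, 3, 768
num_patches = (img_size // patch_size) ** 2      # 14 * 14 = 196 patches
patch_dim = channels * patch_size * patch_size   # 3 * 16 * 16 = 768 values per flattened patch
seq_len = num_patches + 1                        # +1 for the [CLS] token -> 197 tokens
print(num_patches, patch_dim, seq_len)           # 196 768 197
# After the linear projection, the encoder sees a (batch, 197, emb_dim) tensor.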
 
Why ViT?
| Strengths | Limitations | 
|---|---|
| Excellent at modeling global relationships; scalable with data; unified architecture across vision and NLP. | Requires large pretraining datasets; attention scales quadratically with patches (compute heavy); less sample-efficient on small datasets. | 
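The quadratic-attention limitation in the table is easy to quantify from token counts alone. The sketch below (patch size and resolutions chosen only for illustration) shows how the size of the self-attention score matrix grows with input resolution.

# Self-attention builds an (N x N) score matrix over the N tokens, so roughly
# doubling the resolution quadruples the token count and ~16x the attention cost.
patch_size = 16
for img_size in (224, 384, 448):
    n_tokens = (img_size // patch_size) ** 2 + 1   # patches + [CLS]
    attn_entries = n_tokens ** 2                   # score-matrix entries per head, per layer
    print(f"{img_size}px -> {n_tokens} tokens, {attn_entries:,} attention entries")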
Minimal PyTorch example — Patchify + Transformer Encoder
This example shows the core idea: converting an image into patches and running a small Transformer encoder. It is educational — not optimized for production.
# Requirements: torch, torchvision
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
# 1) Patchify helper
def image_to_patches(img_tensor, patch_size):
    # img_tensor: (C, H, W)
    C, H, W = img_tensor.shape
    assert H % patch_size == 0 and W % patch_size == 0
    patches = img_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # patches shape: (C, num_h, num_w, patch_size, patch_size)
    patches = patches.permute(1,2,0,3,4).contiguous()
    num_patches = patches.shape[0] * patches.shape[1]
    patches = patches.view(num_patches, -1)  # (num_patches, C*patch_size*patch_size)
    return patches
# 2) Simple ViT-like module
class SimpleViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, emb_dim=128, num_heads=4, depth=6, num_classes=1000):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.patch_dim = 3 * patch_size * patch_size
        self.to_patch_emb = nn.Linear(self.patch_dim, emb_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.pos_emb = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=emb_dim*4, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.mlp_head = nn.Sequential(nn.LayerNorm(emb_dim), nn.Linear(emb_dim, num_classes))
    def forward(self, x):
        # x: (B, C, H, W)
        B = x.shape[0]
        patches = []
        for i in range(B):
            p = image_to_patches(x[i], self.patch_size)  # (num_patches, patch_dim)
            patches.append(p)
        patches = torch.stack(patches)  # (B, num_patches, patch_dim)
        patches = self.to_patch_emb(patches)  # (B, num_patches, emb_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B,1,emb_dim)
        x = torch.cat([cls_tokens, patches], dim=1)  # (B, num_patches+1, emb_dim)
        x = x + self.pos_emb
        # With batch_first=True the encoder consumes (B, S, E) directly
        x = self.transformer(x)
        cls_final = x[:, 0]  # (B, emb_dim)
        return self.mlp_head(cls_final)
# 3) Quick run (random image)
if __name__ == '__main__':
    img = Image.new('RGB', (224,224), color='white')
    tf = T.Compose([T.ToTensor()])
    t = tf(img).unsqueeze(0)  # (1,3,224,224)
    model = SimpleViT(img_size=224, patch_size=16, emb_dim=128, num_heads=4, depth=4, num_classes=10)
    logits = model(t)
    print('logits', logits.shape)
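The per-image loop above keeps the patchify step easy to follow, but most real implementations (timm's PatchEmbed, for example) fuse patch extraction and the linear projection into a single strided convolution. A minimal sketch of that equivalent formulation, reusing the emb_dim=128 and patch_size=16 from SimpleViT above:

# Patchify + linear projection as one strided Conv2d: a kernel of size patch_size
# with stride patch_size applies the same linear map to every non-overlapping patch.
import torch
import torch.nn as nn

patch_embed = nn.Conv2d(in_channels=3, out_channels=128, kernel_size=16, stride=16)
x = torch.randn(2, 3, 224, 224)             # (B, C, H, W)
tokens = patch_embed(x)                     # (B, 128, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)  # (B, 196, 128): one embedding per patch
print(tokens.shape)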
Using a pretrained ViT (recommended for real tasks)
Use libraries like timm or torchvision's pretrained models to leverage large-scale pretraining.
# Requirements: timm, torch, torchvision
# pip install timm
import timm
import torch
from timm.data import resolve_data_config, create_transform
from PIL import Image
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()
# Preprocess using the model's own data config: different pretrained ViT
# weights expect different input sizes and normalization statistics, so
# avoid hard-coding the ImageNet mean/std
config = resolve_data_config({}, model=model)
preprocess = create_transform(**config)
img = Image.open('img/example.jpg').convert('RGB')
input_tensor = preprocess(img).unsqueeze(0)  # (1,3,224,224)
with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top5 = torch.topk(probs, k=5)
    print('Top5 indices:', top5.indices)
Notes:
- ViT models trained from scratch need very large datasets to reach their best accuracy; for smaller datasets, start from pretrained weights and fine-tune (a minimal fine-tuning sketch follows below).
- Consider efficient variants (Swin Transformer, DeiT) for resource-constrained environments.
- The minimal example above illustrates the core idea; production implementations rely on optimized libraries.
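As a concrete example of the transfer-learning route, timm can load pretrained ViT weights with a freshly initialized classification head sized for your task. The sketch below assumes a hypothetical 10-class problem and uses random tensors in place of a real preprocessed dataset; it shows the skeleton of a fine-tuning loop, not a tuned recipe.

# Fine-tuning sketch: pretrained backbone, new head (num_classes is task-specific)
import timm
import torch
from torch.utils.data import DataLoader, TensorDataset

model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
criterion = torch.nn.CrossEntropyLoss()

# Dummy tensors stand in for a real, preprocessed dataset; swap in your own DataLoader
images, labels = torch.randn(8, 3, 224, 224), torch.randint(0, 10, (8,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=4)

model.train()
for batch_images, batch_labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(batch_images), batch_labels)
    loss.backward()
    optimizer.step()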
 
Resources & further reading
- Original ViT paper: Dosovitskiy et al., 2020, "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
- DeiT (Data-efficient Image Transformers)
- Swin Transformer: a hierarchical, more efficient ViT variant
- timm library: many pretrained vision models