Pruning transforms exponential backtracking into something practical. Always ask: can I reject this branch without exploring it?
PRUNING STRATEGIES
──────────────────
1. CONSTRAINT PRUNING — skip invalid choices early
if not is_valid(choice): continue
Example: N-Queens skip attacked cells
2. SORTED + SKIP DUPLICATES
nums.sort()
if i > start and nums[i] == nums[i-1]: continue
Example: Combination Sum II, Subsets II
3. REMAINING SUM CHECK — can we still reach target?
if remaining < 0: return # overshot
if remaining < min_choice: return # can't reach
4. BOUND CHECK — can this branch still beat the best answer found so far?
if optimistic_bound <= best_so_far: return # maximization; flip the comparison when minimizing
Example: branch and bound
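5. SYMMETRY BREAKING — avoid exploring mirror-image (symmetric) solutions
Restrict the first choice to one representative per symmetry class, explore only from there, and derive the remaining solutions by symmetry
Example: N-Queens — place the first queen only in the left half of row 0, then mirror each solution (handle the middle column separately when n is odd)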
ORDER: sort choices so pruning triggers earlier.
Try most constrained variable first (MRV).
Store queen positions as board[row] = col. Check column clash and diagonal clash (|col_diff| == row_diff).
def is_valid(board, row, col, n):
    """Return True if a queen at (row, col) is not attacked by rows 0..row-1.

    board[i] = column of queen in row i; only the first `row` entries are
    read. `n` (the board size) is accepted for interface compatibility but
    is not used.
    """
    # A clash is either the same column, or a column distance equal to the
    # row distance (the two queens share a diagonal).
    return not any(
        board[r] == col or abs(board[r] - col) == row - r
        for r in range(row)
    )
Place queens row by row, checking column and diagonal conflicts. Classic backtracking with implicit undo.
def solveNQueens(n):
    """Return every placement of n non-attacking queens on an n x n board.

    Each solution is a list of n strings; row r has 'Q' at the queen's
    column and '.' elsewhere. Queens are placed one per row, and columns
    are tried left to right.
    """
    solutions = []
    cols = [-1] * n  # cols[r] = column of the queen in row r (-1 = empty)

    def attacked(row, col):
        # Same column, or a shared diagonal with any queen in an earlier row.
        return any(cols[r] == col or abs(cols[r] - col) == row - r
                   for r in range(row))

    def place(row):
        if row == n:
            solutions.append(['.' * c + 'Q' + '.' * (n - c - 1) for c in cols])
            return
        for col in range(n):
            if attacked(row, col):
                continue
            cols[row] = col
            place(row + 1)
            cols[row] = -1  # undo before trying the next column

    place(0)
    return solutions
Find empty cell, try digits 1-9 with constraint checks (row, col, 3x3 box). Backtrack on failure.
def solveSudoku(board):
    """Fill a 9x9 Sudoku `board` in place; returns None.

    `board` is a list of 9 lists of single-character strings, with '.'
    marking an empty cell. The first empty cell in row-major order is
    filled with the first digit '1'..'9' that does not repeat in its row,
    column or 3x3 box, backtracking on dead ends. If the puzzle has no
    solution, the board is left with its original cells.
    """
    digits = '123456789'

    def fits(r, c, ch):
        top, left = 3 * (r // 3), 3 * (c // 3)
        return all(
            board[r][k] != ch
            and board[k][c] != ch
            and board[top + k // 3][left + k % 3] != ch
            for k in range(9)
        )

    def first_empty():
        return next(((r, c) for r in range(9) for c in range(9)
                     if board[r][c] == '.'), None)

    def fill():
        cell = first_empty()
        if cell is None:
            return True  # no empty cells left: solved
        r, c = cell
        for ch in digits:
            if not fits(r, c, ch):
                continue
            board[r][c] = ch
            if fill():
                return True
            board[r][c] = '.'  # undo and try the next digit
        return False

    fill()
Grid backtracking: mark cell as visited, explore 4 directions, unmark on backtrack (on every exit path, including success, so the caller's board is left intact). O(m*n*4^L), where L = len(word).
def exist(board, word):
    """Return True if `word` can be traced through edge-adjacent cells of `board`.

    `board` is a list of equal-length lists of characters. A path moves up,
    down, left or right and may use each cell at most once. While a path is
    explored, its cells are overwritten with a private sentinel. They are
    always restored via try/finally, so `board` is unchanged when this
    returns, whether or not the word was found. An empty board contains
    only the empty word.
    """
    if not board or not board[0]:
        return not word
    m, n = len(board), len(board[0])
    if len(word) > m * n:
        return False  # a path cannot revisit cells, so it spells at most m*n letters
    # A fresh object never compares equal to a character, unlike '#', which
    # could match a '#' inside `word` and let a visited cell be reused.
    visited = object()

    def dfs(i, j, k):
        if k == len(word):
            return True
        if i < 0 or i >= m or j < 0 or j >= n:
            return False
        if board[i][j] != word[k]:
            return False
        saved, board[i][j] = board[i][j], visited  # mark visited
        try:
            return any(dfs(i + di, j + dj, k + 1)
                       for di, dj in ((0, 1), (0, -1), (1, 0), (-1, 0)))
        finally:
            board[i][j] = saved  # unmark on every exit, including success

    return any(dfs(i, j, 0) for i in range(m) for j in range(n))
Reuse allowed: recurse with start=i (not i+1). Sort + break when candidate > remaining for pruning.
def combinationSum(candidates, target):
    """Return all combinations of `candidates` that sum to `target`.

    Each candidate may be reused any number of times. Every combination is
    returned as a non-decreasing list. `candidates` is not modified: a
    sorted copy is searched instead. Candidates are expected to be
    positive; a zero or negative value with a positive target would
    recurse without bound.
    """
    nums = sorted(candidates)  # sorted copy; the caller's list keeps its order
    res = []

    def backtrack(start, path, remaining):
        if remaining == 0:
            res.append(path[:])
            return
        for i in range(start, len(nums)):
            if nums[i] > remaining:
                break  # sorted: every later candidate is too large as well
            path.append(nums[i])
            backtrack(i, path, remaining - nums[i])  # i, not i+1: reuse allowed
            path.pop()

    backtrack(0, [], target)
    return res
No reuse: recurse with i+1. Skip duplicates: if i > start and same as previous, skip. Must sort first.
def combinationSum2(candidates, target):
    """Return all unique combinations of `candidates` that sum to `target`.

    Each element may be used at most once. Duplicate values in `candidates`
    never produce duplicate combinations. Every combination is a
    non-decreasing list. `candidates` is not modified: a sorted copy is
    searched instead. Values are assumed positive, which is what makes the
    early `break` valid.
    """
    nums = sorted(candidates)  # sorted copy enables the break-prune and dup-skip
    res = []

    def backtrack(start, path, remaining):
        if remaining == 0:
            res.append(path[:])
            return
        for i in range(start, len(nums)):
            if nums[i] > remaining:
                break  # sorted: every later value is too large as well
            if i > start and nums[i] == nums[i - 1]:
                continue  # this value was already tried at this depth
            path.append(nums[i])
            backtrack(i + 1, path, remaining - nums[i])  # i + 1: no reuse
            path.pop()

    backtrack(0, [], target)
    return res
Try every prefix as a partition. If it's a palindrome, recurse on the rest. Enumerate all valid partitions.
def partition(s):
    """Return every way to split `s` into palindromic substrings.

    Each split is a list of pieces that concatenate back to `s`. Shorter
    leading pieces come first, in depth-first order. An empty string
    yields a single empty split.
    """
    def splits(begin):
        # Yield every palindromic split of the suffix s[begin:].
        if begin == len(s):
            yield []
            return
        for stop in range(begin + 1, len(s) + 1):
            piece = s[begin:stop]
            if piece != piece[::-1]:
                continue  # not a palindrome: this prefix cannot start a split
            for rest in splits(stop):
                yield [piece] + rest

    return list(splits(0))
Place 3 dots in the string to create 4 parts. Each part: 1-3 digits, no leading zeros, value 0-255.
def restoreIpAddresses(s):
    """Return every valid dotted IPv4 address formed by inserting 3 dots into `s`.

    Each of the 4 parts has 1-3 digits, no leading zero (except "0"
    itself), and a value of 0-255. Inputs that are not 4-12 ASCII digits
    yield []. This check matters because `int()` also accepts underscores,
    signs and whitespace (e.g. int("1_2") == 12), which would otherwise let
    invalid segments through or raise ValueError. Results are in
    depth-first order, with shorter segments first.
    """
    res = []
    size = len(s)
    if not 4 <= size <= 12 or any(ch not in '0123456789' for ch in s):
        return res

    def backtrack(start, parts):
        if len(parts) == 4:
            if start == size:
                res.append('.'.join(parts))
            return
        left = 4 - len(parts)
        remaining = size - start
        if remaining < left or remaining > 3 * left:
            return  # the rest cannot be split into `left` parts of 1-3 digits
        for length in range(1, 4):
            if start + length > size:
                break
            seg = s[start:start + length]
            if (seg[0] == '0' and length > 1) or int(seg) > 255:
                continue
            backtrack(start + length, parts + [seg])

    backtrack(0, [])
    return res
Each digit maps to letters. At each position, branch on all mapped letters. O(4^n) worst case.
def letterCombinations(digits):
    """Return every letter string a phone keypad can spell for `digits`.

    Digits '2'-'9' map to their keypad letters. An empty input yields [].
    Any other digit raises KeyError. Combinations are ordered as if the
    leftmost digit varied slowest.
    """
    if not digits:
        return []
    keypad = {'2':'abc','3':'def','4':'ghi','5':'jkl',
              '6':'mno','7':'pqrs','8':'tuv','9':'wxyz'}
    combos = ['']
    # Grow every prefix by each letter of the next digit, one digit at a time.
    for d in digits:
        letters = keypad[d]
        combos = [prefix + ch for prefix in combos for ch in letters]
    return combos
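Try placing each number into each of the k buckets. Prune: skip a bucket if adding the number would overflow the target, and skip a bucket whose current sum equals one already tried for this number (an identical, symmetric state).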
def canPartitionKSubsets(nums, k):
    """Return True if `nums` can be split into `k` subsets with equal sums.

    Numbers are assigned, largest first, to buckets of capacity
    sum(nums) // k. `nums` is not modified: a sorted copy is used. Values
    are assumed non-negative, which the overflow prune relies on.
    """
    total = sum(nums)
    if total % k:
        return False
    target = total // k
    items = sorted(nums, reverse=True)  # copy; large items first fail fastest
    if items and items[0] > target:
        return False  # the largest item fits in no bucket
    buckets = [0] * k

    def backtrack(i):
        if i == len(items):
            return all(b == target for b in buckets)
        seen = set()  # bucket sums already tried for items[i]
        for j in range(k):
            load = buckets[j]
            if load + items[i] > target or load in seen:
                continue  # overflow, or a state equivalent to one already tried
            seen.add(load)
            buckets[j] += items[i]
            if backtrack(i + 1):
                return True
            buckets[j] -= items[i]
        return False

    return backtrack(0)
Grid backtracking: visit cells with gold, mark visited by zeroing, restore on backtrack. Try every start.
def getMaximumGold(grid):
    """Return the most gold collectable along one path through `grid`.

    A path starts on any cell with gold, moves up/down/left/right, never
    revisits a cell, and never enters a 0 cell. Cells are zeroed while on
    the current path and restored afterwards, so `grid` is unchanged on
    return. A grid with no gold, or an empty grid, yields 0 (previously,
    `max()` of an empty sequence raised ValueError).
    """
    if not grid or not grid[0]:
        return 0
    m, n = len(grid), len(grid[0])

    def dfs(i, j):
        if i < 0 or i >= m or j < 0 or j >= n or grid[i][j] == 0:
            return 0
        val = grid[i][j]
        grid[i][j] = 0  # mark visited
        best = 0
        for di, dj in (0, 1), (0, -1), (1, 0), (-1, 0):
            best = max(best, dfs(i + di, j + dj))
        grid[i][j] = val  # unmark
        return val + best

    return max((dfs(i, j)
                for i in range(m) for j in range(n)
                if grid[i][j] > 0), default=0)
Permutations use a used[] mask instead of start-index. Each level tries all unused elements. O(n!) results.
def permute(nums):
    """Return all orderings of `nums` as a list of n! lists.

    Positions are permuted, not values, so repeated values give repeated
    permutations. At each level the remaining positions are tried in
    increasing index order.
    """
    def build(pool):
        # Every ordering of the elements at the indices in `pool`.
        if not pool:
            return [[]]
        orderings = []
        for pos, idx in enumerate(pool):
            rest = pool[:pos] + pool[pos + 1:]
            for tail in build(rest):
                orderings.append([nums[idx]] + tail)
        return orderings

    return build(list(range(len(nums))))
At each index, branch into include or exclude. Generates all 2^n subsets. Alternative: iterative bit-mask approach.
def subsets(nums):
    """Return all 2**len(nums) subsets of `nums`, each as a new list.

    The order matches an exclude-before-include recursion over the indices.
    Index 0 acts as the most significant bit of a counter running from 0
    to 2**n - 1, so the empty subset comes first and the full list last.
    """
    size = len(nums)
    return [
        [nums[idx] for idx in range(size) if mask >> (size - 1 - idx) & 1]
        for mask in range(1 << size)
    ]