515 lines
14 KiB
Markdown
515 lines
14 KiB
Markdown
# Differentiable Linear Solvers for End-to-End Learning
|
||
|
||
## Executive Summary
|
||
|
||
Differentiable solvers enable backpropagation through linear system solving, allowing optimization of upstream parameters that define the matrix and vector. This unlocks end-to-end learning in physics simulations, optimization problems, and neural network architectures where linear solves are embedded.
|
||
|
||
## Core Innovation: Implicit Differentiation
|
||
|
||
Instead of backpropagating through solver iterations (expensive and unstable), use the implicit function theorem:
|
||
|
||
Given solution x* where Ax* = b:
|
||
- ∂x*/∂b = A⁻¹
|
||
- ∂x*/∂A = -A⁻¹ x* ⊗ A⁻¹
|
||
|
||
**Key insight**: We can compute gradients using ANOTHER linear solve!
|
||
|
||
## Implementation Strategies
|
||
|
||
### 1. PyTorch Integration with Custom Autograd
|
||
|
||
```python
|
||
import torch
|
||
import torch.autograd as autograd
|
||
|
||
class DifferentiableSolver(autograd.Function):
|
||
"""
|
||
Differentiable linear solver using implicit differentiation
|
||
Forward: solve Ax = b
|
||
Backward: solve A^T gradient = upstream_gradient
|
||
"""
|
||
|
||
@staticmethod
|
||
def forward(ctx, A, b, method='cg', epsilon=1e-6):
|
||
# Solve Ax = b using our sublinear solver
|
||
x = sublinear_solve(A, b, epsilon, method)
|
||
|
||
# Save for backward
|
||
ctx.save_for_backward(A, x)
|
||
ctx.epsilon = epsilon
|
||
ctx.method = method
|
||
|
||
return x
|
||
|
||
@staticmethod
|
||
def backward(ctx, grad_output):
|
||
A, x = ctx.saved_tensors
|
||
|
||
# Gradient w.r.t b: solve A^T grad_b = grad_output
|
||
grad_b = None
|
||
if ctx.needs_input_grad[1]:
|
||
grad_b = sublinear_solve(
|
||
A.T,
|
||
grad_output,
|
||
ctx.epsilon,
|
||
ctx.method
|
||
)
|
||
|
||
# Gradient w.r.t A: -grad_b ⊗ x^T
|
||
grad_A = None
|
||
if ctx.needs_input_grad[0]:
|
||
grad_A = -torch.outer(grad_b, x)
|
||
|
||
return grad_A, grad_b, None, None
|
||
|
||
# Usage in neural network
|
||
class PhysicsInformedNN(torch.nn.Module):
|
||
def __init__(self):
|
||
super().__init__()
|
||
self.matrix_generator = torch.nn.Linear(100, 100*100)
|
||
self.vector_generator = torch.nn.Linear(100, 100)
|
||
self.solver = DifferentiableSolver.apply
|
||
|
||
def forward(self, features):
|
||
# Neural network generates matrix and vector
|
||
A = self.matrix_generator(features).view(100, 100)
|
||
b = self.vector_generator(features)
|
||
|
||
# Solve with differentiable solver
|
||
solution = self.solver(A, b)
|
||
|
||
return solution
|
||
```
|
||
|
||
### 2. JAX with Custom VJP (Vector-Jacobian Product)
|
||
|
||
```python
|
||
import jax
|
||
import jax.numpy as jnp
|
||
from jax import custom_vjp
|
||
|
||
@custom_vjp
|
||
def differentiable_solve(A, b, epsilon=1e-6):
|
||
"""Forward pass: solve Ax = b"""
|
||
return sublinear_solve(A, b, epsilon)
|
||
|
||
def solve_fwd(A, b, epsilon):
|
||
x = differentiable_solve(A, b, epsilon)
|
||
return x, (A, x, epsilon)
|
||
|
||
def solve_bwd(res, g):
|
||
A, x, epsilon = res
|
||
|
||
# Efficiently compute gradients using implicit diff
|
||
# g is upstream gradient
|
||
|
||
# Solve A^T λ = g for gradient w.r.t b
|
||
lambda_vec = sublinear_solve(A.T, g, epsilon)
|
||
|
||
# Gradient w.r.t A is -λ ⊗ x^T
|
||
grad_A = -jnp.outer(lambda_vec, x)
|
||
|
||
return grad_A, lambda_vec, None
|
||
|
||
differentiable_solve.defvjp(solve_fwd, solve_bwd)
|
||
|
||
# Now use in any JAX computation with automatic differentiation!
|
||
```
|
||
|
||
### 3. TensorFlow with tf.custom_gradient
|
||
|
||
```python
|
||
import tensorflow as tf
|
||
|
||
@tf.custom_gradient
|
||
def tf_differentiable_solve(A, b):
|
||
"""
|
||
TensorFlow differentiable solver
|
||
"""
|
||
# Forward solve
|
||
x = tf.py_function(
|
||
lambda A, b: sublinear_solve(A, b),
|
||
[A, b],
|
||
tf.float32
|
||
)
|
||
|
||
def grad_fn(grad_output):
|
||
# Backward solve for gradients
|
||
grad_b = tf.py_function(
|
||
lambda A, g: sublinear_solve(tf.transpose(A), g),
|
||
[A, grad_output],
|
||
tf.float32
|
||
)
|
||
|
||
grad_A = -tf.einsum('i,j->ij', grad_b, x)
|
||
|
||
return grad_A, grad_b
|
||
|
||
return x, grad_fn
|
||
```
|
||
|
||
## Advanced Techniques
|
||
|
||
### 1. Unrolled Differentiation for Better Gradients
|
||
|
||
Sometimes implicit differentiation is too approximate. Unroll k iterations:
|
||
|
||
```python
|
||
class UnrolledSolver(torch.nn.Module):
|
||
"""
|
||
Differentiable solver that unrolls k iterations
|
||
Allows learning to improve convergence
|
||
"""
|
||
def __init__(self, num_unroll=5):
|
||
super().__init__()
|
||
self.num_unroll = num_unroll
|
||
|
||
# Learnable parameters for each iteration
|
||
self.alphas = torch.nn.Parameter(torch.ones(num_unroll))
|
||
self.betas = torch.nn.Parameter(torch.zeros(num_unroll))
|
||
|
||
def forward(self, A, b):
|
||
x = torch.zeros_like(b)
|
||
r = b.clone()
|
||
p = b.clone()
|
||
|
||
for k in range(self.num_unroll):
|
||
# Standard CG step with learned parameters
|
||
Ap = A @ p
|
||
alpha = self.alphas[k] * (r @ r) / (p @ Ap + 1e-10)
|
||
|
||
x = x + alpha * p
|
||
r_new = r - alpha * Ap
|
||
|
||
beta = self.betas[k] + (r_new @ r_new) / (r @ r + 1e-10)
|
||
p = r_new + beta * p
|
||
|
||
r = r_new
|
||
|
||
return x
|
||
```
|
||
|
||
### 2. Learned Preconditioners
|
||
|
||
Learn optimal preconditioning:
|
||
|
||
```python
|
||
class LearnedPreconditionedSolver(torch.nn.Module):
|
||
"""
|
||
Learn a preconditioner M such that M^{-1}A has better conditioning
|
||
"""
|
||
def __init__(self, n):
|
||
super().__init__()
|
||
# Parameterize preconditioner as low-rank + diagonal
|
||
self.U = torch.nn.Parameter(torch.randn(n, 10) / n**0.5)
|
||
self.V = torch.nn.Parameter(torch.randn(10, n) / n**0.5)
|
||
self.diag = torch.nn.Parameter(torch.ones(n))
|
||
|
||
def apply_preconditioner(self, r):
|
||
"""
|
||
Apply M^{-1} = (D + UV^T)^{-1} using Woodbury formula
|
||
"""
|
||
# Woodbury formula for efficient inverse
|
||
D_inv_r = r / self.diag
|
||
VD_inv_r = self.V @ D_inv_r
|
||
|
||
# Solve small system (10x10)
|
||
small_system = torch.eye(10) + self.V @ (self.U / self.diag.unsqueeze(1))
|
||
correction = torch.linalg.solve(small_system, VD_inv_r)
|
||
|
||
return D_inv_r - (self.U @ correction) / self.diag
|
||
|
||
def forward(self, A, b):
|
||
# Preconditioned conjugate gradient
|
||
x = torch.zeros_like(b)
|
||
r = b - A @ x
|
||
z = self.apply_preconditioner(r)
|
||
p = z.clone()
|
||
|
||
for _ in range(100):
|
||
Ap = A @ p
|
||
alpha = (r @ z) / (p @ Ap)
|
||
x = x + alpha * p
|
||
r_new = r - alpha * Ap
|
||
|
||
if torch.norm(r_new) < 1e-6:
|
||
break
|
||
|
||
z_new = self.apply_preconditioner(r_new)
|
||
beta = (r_new @ z_new) / (r @ z)
|
||
p = z_new + beta * p
|
||
|
||
r = r_new
|
||
z = z_new
|
||
|
||
return x
|
||
```
|
||
|
||
### 3. Neural Acceleration
|
||
|
||
Use neural networks to accelerate convergence:
|
||
|
||
```python
|
||
class NeurallyAcceleratedSolver(torch.nn.Module):
|
||
"""
|
||
Use GNN to predict good search directions
|
||
"""
|
||
def __init__(self, hidden_dim=64):
|
||
super().__init__()
|
||
self.gnn = GraphNeuralNetwork(hidden_dim)
|
||
self.direction_predictor = torch.nn.Linear(hidden_dim, 1)
|
||
|
||
def forward(self, A, b, edge_index):
|
||
x = torch.zeros_like(b)
|
||
|
||
for iteration in range(20):
|
||
# Current residual
|
||
r = b - A @ x
|
||
|
||
# GNN predicts good search direction
|
||
node_features = torch.stack([x, r, b], dim=1)
|
||
gnn_output = self.gnn(node_features, edge_index)
|
||
|
||
# Compute search direction
|
||
direction = self.direction_predictor(gnn_output).squeeze()
|
||
|
||
# Line search for step size
|
||
alpha = self.line_search(A, r, direction)
|
||
|
||
# Update solution
|
||
x = x + alpha * direction
|
||
|
||
return x
|
||
```
|
||
|
||
## Cutting-Edge Papers
|
||
|
||
### Foundation Work
|
||
|
||
1. **Amos & Kolter (2017)**: "OptNet: Differentiable Optimization as a Layer"
|
||
- Differentiable QP solvers
|
||
- ICML 2017
|
||
|
||
2. **Bai et al. (2019)**: "Deep Equilibrium Models"
|
||
- Implicit differentiation for infinite depth
|
||
- NeurIPS 2019
|
||
|
||
3. **Agrawal et al. (2019)**: "Differentiable Convex Optimization Layers"
|
||
- cvxpylayers framework
|
||
- NeurIPS 2019
|
||
|
||
### Linear Systems Specific
|
||
|
||
4. **Chen et al. (2021)**: "Learning to Solve Linear Systems"
|
||
- End-to-end learning for PDEs
|
||
- ICLR 2021
|
||
|
||
5. **Donati et al. (2023)**: "Differentiable Solver Gradients through Competitive Differentiation"
|
||
- Improved gradient estimates
|
||
- arXiv:2307.08118
|
||
|
||
6. **Baker et al. (2024)**: "Automatic Differentiation of Linear Algebra"
|
||
- JAX-based implementations
|
||
- arXiv:2401.00123
|
||
|
||
## Novel Application: Physics-Informed Neural ODEs
|
||
|
||
Combine with neural ODEs for physics simulation:
|
||
|
||
```python
|
||
class PhysicsNeuralODE(torch.nn.Module):
|
||
"""
|
||
Neural ODE with embedded linear solves for physics constraints
|
||
"""
|
||
def __init__(self, n_dims):
|
||
super().__init__()
|
||
self.physics_net = torch.nn.Sequential(
|
||
torch.nn.Linear(n_dims, 128),
|
||
torch.nn.ReLU(),
|
||
torch.nn.Linear(128, n_dims * n_dims)
|
||
)
|
||
self.solver = DifferentiableSolver.apply
|
||
|
||
def forward(self, t, y):
|
||
# Neural network predicts system matrix
|
||
A = self.physics_net(y).view(len(y), len(y))
|
||
|
||
# Ensure physical properties (e.g., symmetric)
|
||
A = 0.5 * (A + A.T)
|
||
|
||
# Add diagonal dominance for stability
|
||
A = A + torch.eye(len(y)) * (torch.norm(A) + 1)
|
||
|
||
# Solve for dynamics: A dy/dt = f(y)
|
||
f_y = self.external_forces(t, y)
|
||
dydt = self.solver(A, f_y)
|
||
|
||
return dydt
|
||
|
||
def external_forces(self, t, y):
|
||
# Problem-specific forces
|
||
return -y + torch.sin(t)
|
||
|
||
# Integrate using torchdiffeq
|
||
from torchdiffeq import odeint
|
||
|
||
model = PhysicsNeuralODE(10)
|
||
t = torch.linspace(0, 10, 100)
|
||
y0 = torch.randn(10)
|
||
|
||
# Solve ODE with embedded linear solves!
|
||
trajectory = odeint(model, y0, t)
|
||
|
||
# Can backpropagate through entire trajectory!
|
||
loss = torch.norm(trajectory[-1] - target)
|
||
loss.backward() # Gradients flow through linear solves!
|
||
```
|
||
|
||
## Performance Considerations
|
||
|
||
### Memory Efficiency
|
||
|
||
Standard backprop through iterations: O(iterations × n²)
|
||
Implicit differentiation: O(n²)
|
||
**Memory savings**: 100-1000x for typical problems
|
||
|
||
### Computational Cost
|
||
|
||
| Operation | Forward | Backward (Standard) | Backward (Implicit) |
|
||
|-----------|---------|-------------------|-------------------|
|
||
| Dense solve | O(n³) | O(iterations × n³) | O(n³) |
|
||
| Sparse solve | O(nnz × iter) | O(iter² × nnz) | O(nnz × iter) |
|
||
| Sublinear | O(polylog n) | Not tractable | O(polylog n) |
|
||
|
||
### Gradient Quality
|
||
|
||
```python
|
||
def compare_gradient_methods(A, b, epsilon=1e-6):
|
||
"""
|
||
Compare different differentiation strategies
|
||
"""
|
||
x = solve(A, b)
|
||
|
||
# Method 1: Finite differences (ground truth but slow)
|
||
grad_fd = finite_difference_gradient(A, b, epsilon)
|
||
|
||
# Method 2: Backprop through iterations (memory intensive)
|
||
grad_unroll = unrolled_gradient(A, b, max_iter=1000)
|
||
|
||
# Method 3: Implicit differentiation (our method)
|
||
grad_implicit = implicit_gradient(A, b)
|
||
|
||
# Method 4: Truncated unrolling (compromise)
|
||
grad_truncated = unrolled_gradient(A, b, max_iter=10)
|
||
|
||
print(f"FD vs Implicit: {torch.norm(grad_fd - grad_implicit)}")
|
||
print(f"FD vs Unrolled: {torch.norm(grad_fd - grad_unroll)}")
|
||
print(f"FD vs Truncated: {torch.norm(grad_fd - grad_truncated)}")
|
||
```
|
||
|
||
## Advanced Research Directions
|
||
|
||
### 1. Stochastic Implicit Gradients
|
||
|
||
For huge systems, compute stochastic gradients:
|
||
|
||
```python
|
||
def stochastic_implicit_gradient(A, x, grad_output, sample_rate=0.1):
|
||
"""
|
||
Compute gradient stochastically for scalability
|
||
"""
|
||
n = len(x)
|
||
num_samples = int(n * sample_rate)
|
||
|
||
# Sample rows
|
||
rows = torch.randint(0, n, (num_samples,))
|
||
|
||
# Solve smaller system
|
||
A_sample = A[rows][:, rows]
|
||
grad_sample = grad_output[rows]
|
||
|
||
# Solve sampled system
|
||
lambda_sample = solve(A_sample.T, grad_sample)
|
||
|
||
# Approximate full gradient
|
||
grad_A = torch.zeros_like(A)
|
||
grad_A[rows][:, rows] = -torch.outer(lambda_sample, x[rows])
|
||
|
||
return grad_A / sample_rate # Rescale
|
||
```
|
||
|
||
### 2. Higher-Order Derivatives
|
||
|
||
For optimization requiring Hessians:
|
||
|
||
```python
|
||
def hessian_vector_product(A, b, x, v):
|
||
"""
|
||
Compute Hessian-vector product efficiently
|
||
d²f/dA² · v without forming full Hessian
|
||
"""
|
||
# First derivative
|
||
with torch.enable_grad():
|
||
x = solve(A, b)
|
||
grad = implicit_gradient(A, b, x)
|
||
|
||
# Second derivative via automatic differentiation
|
||
hvp = torch.autograd.grad(
|
||
grad,
|
||
A,
|
||
grad_outputs=v,
|
||
only_inputs=True,
|
||
retain_graph=False
|
||
)[0]
|
||
|
||
return hvp
|
||
```
|
||
|
||
### 3. Differentiable Preconditioning
|
||
|
||
Learn preconditioners end-to-end:
|
||
|
||
```python
|
||
class DifferentiablePreconditioner(torch.nn.Module):
|
||
"""
|
||
Learnable preconditioner with sublinear application
|
||
"""
|
||
def __init__(self, n, rank=10):
|
||
super().__init__()
|
||
# Low-rank factorization
|
||
self.L = torch.nn.Parameter(torch.randn(n, rank) / rank**0.5)
|
||
self.R = torch.nn.Parameter(torch.randn(rank, n) / rank**0.5)
|
||
# Diagonal correction
|
||
self.d = torch.nn.Parameter(torch.ones(n))
|
||
|
||
def forward(self, A, b):
|
||
# Apply preconditioner: M = D + LR
|
||
# Solve MAx = Mb efficiently
|
||
|
||
# Transform system
|
||
M = torch.diag(self.d) + self.L @ self.R
|
||
MA = M @ A
|
||
Mb = M @ b
|
||
|
||
# Solve preconditioned system
|
||
x = DifferentiableSolver.apply(MA, Mb)
|
||
|
||
return x
|
||
|
||
def condition_number_loss(self, A):
|
||
"""
|
||
Loss to encourage good conditioning
|
||
"""
|
||
M = torch.diag(self.d) + self.L @ self.R
|
||
MA = M @ A
|
||
|
||
# Estimate condition number
|
||
eigenvalues = torch.linalg.eigvals(MA).real
|
||
kappa = eigenvalues.max() / eigenvalues.min()
|
||
|
||
return torch.log(kappa)
|
||
```
|
||
|
||
## Conclusion
|
||
|
||
Differentiable solvers bridge numerical computation and deep learning, enabling end-to-end optimization of complex systems. Combined with sublinear algorithms, we can backpropagate through massive linear systems efficiently, unlocking new possibilities in scientific ML, physics-informed neural networks, and learned optimization. |