wifi-densepose/vendor/sublinear-time-solver/plans/differentiable-solvers
rUv 407b46b206
feat: vendor midstream and sublinear-time-solver libraries (#109)
Add ruvnet/midstream (AIMDS real-time inference) and
ruvnet/sublinear-time-solver (sublinear optimization algorithms)
as vendored dependencies under vendor/.
2026-03-02 23:34:05 -05:00
..
README.md feat: vendor midstream and sublinear-time-solver libraries (#109) 2026-03-02 23:34:05 -05:00

README.md

Differentiable Linear Solvers for End-to-End Learning

Executive Summary

Differentiable solvers enable backpropagation through linear system solving, allowing optimization of upstream parameters that define the matrix and vector. This unlocks end-to-end learning in physics simulations, optimization problems, and neural network architectures where linear solves are embedded.

Core Innovation: Implicit Differentiation

Instead of backpropagating through solver iterations (expensive and unstable), use the implicit function theorem:

Given solution x* where Ax* = b:

  • ∂x*/∂b = A⁻¹
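  • ∂x*/∂A: dx* = -A⁻¹ (dA) x*, i.e. ∂x*_i/∂A_jk = -(A⁻¹)_ij x*_k. For a scalar loss L with λ = A⁻ᵀ ∂L/∂x*, this gives ∂L/∂b = λ and ∂L/∂A = -λ x*ᵀ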

Key insight: We can compute gradients using ANOTHER linear solve!

Implementation Strategies

1. PyTorch Integration with Custom Autograd

import torch
import torch.autograd as autograd

class DifferentiableSolver(autograd.Function):
    """
    Differentiable linear solver using implicit differentiation.

    Forward:  solve ``A x = b``.
    Backward: solve the adjoint system ``A^T lambda = grad_output`` once and
    derive both input gradients from it:

        grad_b = lambda
        grad_A = -lambda ⊗ x^T

    ``sublinear_solve`` is not defined in this snippet and must be provided
    by the surrounding project. ``torch.outer`` is used for ``grad_A``, so
    ``x`` and ``grad_output`` are assumed to be 1-D vectors.
    """

    @staticmethod
    def forward(ctx, A, b, method='cg', epsilon=1e-6):
        """
        Solve ``A x = b`` and stash what the backward pass needs.

        Args:
            ctx: autograd context object.
            A: system matrix.
            b: right-hand side vector.
            method: solver method name, forwarded to ``sublinear_solve``.
            epsilon: solver tolerance, forwarded to ``sublinear_solve``.

        Returns:
            The solution ``x`` returned by ``sublinear_solve``.
        """
        # Solve Ax = b using our sublinear solver
        x = sublinear_solve(A, b, epsilon, method)

        # Save for backward: the adjoint solve needs A, grad_A needs x.
        ctx.save_for_backward(A, x)
        ctx.epsilon = epsilon
        ctx.method = method

        return x

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute gradients w.r.t. ``A`` and ``b`` with one adjoint solve.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient of the loss w.r.t. the solution ``x``.

        Returns:
            ``(grad_A, grad_b, None, None)``; the trailing ``None`` entries
            correspond to the non-tensor ``method`` and ``epsilon`` inputs.
        """
        A, x = ctx.saved_tensors
        needs_grad_A = ctx.needs_input_grad[0]
        needs_grad_b = ctx.needs_input_grad[1]

        # Nothing upstream wants a gradient: skip the (expensive) solve.
        if not (needs_grad_A or needs_grad_b):
            return None, None, None, None

        # The adjoint vector is required for BOTH gradients, so it must be
        # computed even when only A requires grad (previously grad_A was
        # built from a grad_b that could still be None).
        adjoint = sublinear_solve(
            A.T,
            grad_output,
            ctx.epsilon,
            ctx.method
        )

        # Gradient w.r.t b is the adjoint vector itself.
        grad_b = adjoint if needs_grad_b else None

        # Gradient w.r.t A: -lambda ⊗ x^T
        grad_A = -torch.outer(adjoint, x) if needs_grad_A else None

        return grad_A, grad_b, None, None

# Usage in neural network
class PhysicsInformedNN(torch.nn.Module):
    """
    Network that predicts a linear system from features and solves it.

    Two linear layers map a 100-feature input to a 100x100 matrix and a
    100-vector; the system is then solved with ``DifferentiableSolver`` so
    gradients reach both layers.
    """

    def __init__(self):
        super().__init__()
        size = 100
        # Heads that emit the flattened system matrix and the right-hand side.
        self.matrix_generator = torch.nn.Linear(size, size * size)
        self.vector_generator = torch.nn.Linear(size, size)
        self.solver = DifferentiableSolver.apply

    def forward(self, features):
        """
        Build ``A`` and ``b`` from ``features`` and return the solution of
        ``A x = b``.

        The ``view(100, 100)`` call requires the matrix head to produce
        exactly 10,000 values, i.e. a single unbatched feature vector.
        """
        flat_matrix = self.matrix_generator(features)
        A = flat_matrix.view(100, 100)
        b = self.vector_generator(features)

        # Differentiable solve; backward goes through the custom Function.
        return self.solver(A, b)

2. JAX with Custom VJP (Vector-Jacobian Product)

import jax
import jax.numpy as jnp
from jax import custom_vjp

@custom_vjp
def differentiable_solve(A, b, epsilon=1e-6):
    """
    Primal computation: return the solution of ``A x = b``.

    The work is delegated to ``sublinear_solve`` (defined outside this
    snippet) with tolerance ``epsilon``.
    """
    solution = sublinear_solve(A, b, epsilon)
    return solution

def solve_fwd(A, b, epsilon):
    """
    Forward rule for the custom VJP.

    Returns the primal solution together with the residuals ``(A, x,
    epsilon)`` that the backward rule consumes.
    """
    solution = differentiable_solve(A, b, epsilon)
    residuals = (A, solution, epsilon)
    return solution, residuals

def solve_bwd(res, g):
    """
    Backward rule for the custom VJP, via implicit differentiation.

    Args:
        res: residuals ``(A, x, epsilon)`` saved by ``solve_fwd``.
        g: upstream cotangent with respect to the solution ``x``.

    Returns:
        ``(grad_A, grad_b, None)``: cotangents for ``A`` and ``b``, and
        ``None`` for ``epsilon``.
    """
    system, solution, tolerance = res

    # One adjoint solve, A^T λ = g, yields the cotangent for b directly.
    adjoint = sublinear_solve(system.T, g, tolerance)

    # The cotangent for A is the negated outer product -λ ⊗ x^T.
    cotangent_A = -jnp.outer(adjoint, solution)

    return cotangent_A, adjoint, None

# Attach solve_fwd / solve_bwd as the custom VJP rules of differentiable_solve.
differentiable_solve.defvjp(solve_fwd, solve_bwd)

# Now use in any JAX computation with automatic differentiation!

3. TensorFlow with tf.custom_gradient

import tensorflow as tf

@tf.custom_gradient
def tf_differentiable_solve(A, b):
    """
    TensorFlow differentiable solver.

    Solves ``A x = b`` through ``tf.py_function`` and returns ``x`` together
    with a gradient function that uses one transposed solve to produce the
    gradients for ``A`` and ``b``. ``sublinear_solve`` is defined outside
    this snippet.
    """

    def _solve(matrix, rhs):
        # Forward system A x = b.
        return sublinear_solve(matrix, rhs)

    def _solve_transposed(matrix, cotangent):
        # Adjoint system A^T λ = g.
        return sublinear_solve(tf.transpose(matrix), cotangent)

    # Forward solve
    x = tf.py_function(func=_solve, inp=[A, b], Tout=tf.float32)

    def grad_fn(grad_output):
        # The adjoint solve gives the gradient w.r.t. b ...
        grad_b = tf.py_function(
            func=_solve_transposed,
            inp=[A, grad_output],
            Tout=tf.float32,
        )

        # ... and its negated outer product with x is the gradient w.r.t. A.
        grad_A = -tf.einsum('i,j->ij', grad_b, x)

        return grad_A, grad_b

    return x, grad_fn

Advanced Techniques

1. Unrolled Differentiation for Better Gradients

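Implicit differentiation is only exact when the forward solve has converged; with a loosely converged or truncated solve its gradients can be inaccurate. An alternative is to unroll k solver iterations and differentiate through them directly: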

class UnrolledSolver(torch.nn.Module):
    """
    Conjugate-gradient style solver with a fixed number of unrolled steps.

    Every step owns a learnable multiplicative factor on the step size
    (``alphas``, initialised to 1) and a learnable additive offset on the
    direction-update coefficient (``betas``, initialised to 0).
    """

    def __init__(self, num_unroll=5):
        super().__init__()
        self.num_unroll = num_unroll

        # One learnable scalar pair per unrolled iteration.
        self.alphas = torch.nn.Parameter(torch.ones(num_unroll))
        self.betas = torch.nn.Parameter(torch.zeros(num_unroll))

    def forward(self, A, b):
        """Run ``num_unroll`` steps from a zero initial guess and return the iterate."""
        solution = torch.zeros_like(b)
        residual = b.clone()
        direction = b.clone()

        for step in range(self.num_unroll):
            A_direction = A @ direction
            residual_sq = residual @ residual

            # Step size: the CG formula scaled by the learned factor; the
            # 1e-10 keeps the division finite when the curvature vanishes.
            step_size = self.alphas[step] * residual_sq / (direction @ A_direction + 1e-10)

            solution = solution + step_size * direction
            next_residual = residual - step_size * A_direction

            # Direction update: the CG ratio shifted by the learned offset.
            mix = self.betas[step] + (next_residual @ next_residual) / (residual_sq + 1e-10)
            direction = next_residual + mix * direction

            residual = next_residual

        return solution

2. Learned Preconditioners

Learn optimal preconditioning:

class LearnedPreconditionedSolver(torch.nn.Module):
    """
    Learn a preconditioner M such that M^{-1}A has better conditioning.

    M is parameterised as diagonal plus low rank, ``M = diag(d) + U V``,
    with ``U`` of shape (n, rank) and ``V`` of shape (rank, n).

    NOTE(review): nothing constrains M to be symmetric positive definite,
    which preconditioned CG normally requires — confirm this is acceptable
    for the intended use.
    """

    def __init__(self, n, rank=10):
        """
        Args:
            n: dimension of the linear system.
            rank: rank of the low-rank term (default 10).
        """
        super().__init__()
        # Parameterize preconditioner as low-rank + diagonal
        self.U = torch.nn.Parameter(torch.randn(n, rank) / n**0.5)
        self.V = torch.nn.Parameter(torch.randn(rank, n) / n**0.5)
        self.diag = torch.nn.Parameter(torch.ones(n))

    def apply_preconditioner(self, r):
        """
        Apply M^{-1} = (D + U V)^{-1} to ``r`` using the Woodbury formula:

            M^{-1} r = D^{-1} r - D^{-1} U (I + V D^{-1} U)^{-1} V D^{-1} r

        Only a (rank x rank) system is solved.
        """
        D_inv_r = r / self.diag
        VD_inv_r = self.V @ D_inv_r

        # Size the identity from U and build it with the parameters'
        # dtype/device so any rank, precision or device works.
        rank = self.U.shape[1]
        identity = torch.eye(rank, dtype=self.U.dtype, device=self.U.device)

        # Small (rank x rank) capacitance system: I + V D^{-1} U
        small_system = identity + self.V @ (self.U / self.diag.unsqueeze(1))
        correction = torch.linalg.solve(small_system, VD_inv_r)

        return D_inv_r - (self.U @ correction) / self.diag

    def forward(self, A, b, max_iter=100, tol=1e-6):
        """
        Solve ``A x = b`` with preconditioned conjugate gradient.

        Args:
            A: system matrix.
            b: right-hand side vector.
            max_iter: maximum number of PCG iterations (default 100).
            tol: stop once the residual norm drops below this (default 1e-6).

        Returns:
            The final iterate ``x``.
        """
        x = torch.zeros_like(b)
        r = b - A @ x

        # The zero start is already a solution when the residual is below
        # tolerance; without this guard a zero b gives 0/0 = NaN step sizes.
        if torch.norm(r) < tol:
            return x

        z = self.apply_preconditioner(r)
        p = z.clone()

        for _ in range(max_iter):
            Ap = A @ p
            alpha = (r @ z) / (p @ Ap)
            x = x + alpha * p
            r_new = r - alpha * Ap

            if torch.norm(r_new) < tol:
                break

            z_new = self.apply_preconditioner(r_new)
            beta = (r_new @ z_new) / (r @ z)
            p = z_new + beta * p

            r = r_new
            z = z_new

        return x

3. Neural Acceleration

Use neural networks to accelerate convergence:

class NeurallyAcceleratedSolver(torch.nn.Module):
    """
    Use a GNN to predict good search directions.

    Each iteration feeds the current iterate, the residual and the
    right-hand side to the GNN as per-node features, maps the GNN output to
    one scalar per node (the search direction), and steps along it with a
    line search.

    ``GraphNeuralNetwork`` is not defined in this snippet and must be
    provided by the surrounding project.
    """

    def __init__(self, hidden_dim=64):
        super().__init__()
        self.gnn = GraphNeuralNetwork(hidden_dim)
        self.direction_predictor = torch.nn.Linear(hidden_dim, 1)

    def line_search(self, A, r, direction, eps=1e-10):
        """
        Exact line search for the quadratic ``0.5 x^T A x - b^T x``.

        Returns ``alpha = (r . d) / (d . A d)``, the minimiser along
        ``direction`` when A is symmetric positive definite.
        NOTE(review): the SPD assumption matches the CG-based solvers
        elsewhere in this document — confirm it holds for the target systems.

        Args:
            A: system matrix.
            r: current residual ``b - A x``.
            direction: search direction.
            eps: added to the curvature so a zero direction gives a zero
                step instead of NaN.
        """
        curvature = direction @ (A @ direction)
        return (r @ direction) / (curvature + eps)

    def forward(self, A, b, edge_index, num_iterations=20):
        """
        Iteratively refine ``x`` from a zero initial guess.

        Args:
            A: system matrix.
            b: right-hand side vector.
            edge_index: graph connectivity forwarded to the GNN.
            num_iterations: number of refinement steps (default 20).

        Returns:
            The final iterate ``x``.
        """
        x = torch.zeros_like(b)

        for _ in range(num_iterations):
            # Current residual
            r = b - A @ x

            # GNN predicts good search direction from per-node features (n, 3)
            node_features = torch.stack([x, r, b], dim=1)
            gnn_output = self.gnn(node_features, edge_index)

            # Drop only the trailing size-1 feature axis; a bare squeeze()
            # would also collapse a single-node system to a 0-d tensor.
            direction = self.direction_predictor(gnn_output).squeeze(-1)

            # Line search for step size
            alpha = self.line_search(A, r, direction)

            # Update solution
            x = x + alpha * direction

        return x

Cutting-Edge Papers

Foundation Work

  1. Amos & Kolter (2017): "OptNet: Differentiable Optimization as a Layer"

    • Differentiable QP solvers
    • ICML 2017
  2. Bai et al. (2019): "Deep Equilibrium Models"

    • Implicit differentiation for infinite depth
    • NeurIPS 2019
  3. Agrawal et al. (2019): "Differentiable Convex Optimization Layers"

    • cvxpylayers framework
    • NeurIPS 2019

Linear Systems Specific

  1. Chen et al. (2021): "Learning to Solve Linear Systems"

    • End-to-end learning for PDEs
    • ICLR 2021
  2. Donati et al. (2023): "Differentiable Solver Gradients through Competitive Differentiation"

    • Improved gradient estimates
    • arXiv:2307.08118
  3. Baker et al. (2024): "Automatic Differentiation of Linear Algebra"

    • JAX-based implementations
    • arXiv:2401.00123

Novel Application: Physics-Informed Neural ODEs

Combine with neural ODEs for physics simulation:

class PhysicsNeuralODE(torch.nn.Module):
    """
    ODE right-hand side whose derivative comes from a learned linear solve.

    An MLP maps the state ``y`` to a matrix ``A``; the derivative returned
    is the solution of ``A dy/dt = f(y)`` obtained with
    ``DifferentiableSolver``.
    """

    def __init__(self, n_dims):
        super().__init__()
        hidden_width = 128
        layers = [
            torch.nn.Linear(n_dims, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, n_dims * n_dims),
        ]
        self.physics_net = torch.nn.Sequential(*layers)
        self.solver = DifferentiableSolver.apply

    def forward(self, t, y):
        """Return ``dy/dt`` at time ``t`` for state ``y``."""
        n = len(y)

        # Raw matrix predicted by the network, reshaped to (n, n).
        predicted = self.physics_net(y).view(n, n)

        # Symmetrise, then shift the diagonal by (norm + 1) for stability.
        symmetric = 0.5 * (predicted + predicted.T)
        shift = torch.norm(symmetric) + 1
        system = symmetric + torch.eye(n) * shift

        # Solve for dynamics: A dy/dt = f(y)
        forcing = self.external_forces(t, y)
        return self.solver(system, forcing)

    def external_forces(self, t, y):
        """Problem-specific forcing term: ``-y + sin(t)``."""
        oscillation = torch.sin(t)
        return -y + oscillation

# Integrate using torchdiffeq
from torchdiffeq import odeint

model = PhysicsNeuralODE(10)
t = torch.linspace(0, 10, 100)
y0 = torch.randn(10)

# Desired final state. The example used `target` without defining it
# (NameError); a zero vector is a placeholder — substitute real data.
target = torch.zeros_like(y0)

# Solve ODE with embedded linear solves!
trajectory = odeint(model, y0, t)

# Can backpropagate through entire trajectory!
# The loss is the distance between the final state and the target.
loss = torch.norm(trajectory[-1] - target)
loss.backward()  # Gradients flow through linear solves!

Performance Considerations

Memory Efficiency

  • Standard backprop through iterations: O(iterations × n²)
  • Implicit differentiation: O(n²)
  • Memory savings: 100-1000x for typical problems

Computational Cost

| Operation | Forward | Backward (Standard) | Backward (Implicit) |
| --- | --- | --- | --- |
| Dense solve | O(n³) | O(iterations × n³) | O(n³) |
| Sparse solve | O(nnz × iter) | O(iter² × nnz) | O(nnz × iter) |
| Sublinear | O(polylog n) | Not tractable | O(polylog n) |

Gradient Quality

def compare_gradient_methods(A, b, epsilon=1e-6):
    """
    Print how far each gradient strategy is from finite differences.

    Gradients are computed by finite differences (the reference), full
    unrolling (1000 iterations), implicit differentiation and truncated
    unrolling (10 iterations). The norm of each difference from the
    reference is printed; nothing is returned.

    The helpers ``solve``, ``finite_difference_gradient``,
    ``unrolled_gradient`` and ``implicit_gradient`` are defined outside
    this snippet.
    """
    # The solution is computed up front, as in the reference flow; it is
    # not used by the comparisons below.
    x = solve(A, b)

    # Reference: finite differences (slow).
    reference = finite_difference_gradient(A, b, epsilon)

    # Candidates, in decreasing order of memory cost.
    full_unroll = unrolled_gradient(A, b, max_iter=1000)
    implicit = implicit_gradient(A, b)
    short_unroll = unrolled_gradient(A, b, max_iter=10)

    implicit_gap = torch.norm(reference - implicit)
    print(f"FD vs Implicit: {implicit_gap}")
    unroll_gap = torch.norm(reference - full_unroll)
    print(f"FD vs Unrolled: {unroll_gap}")
    truncated_gap = torch.norm(reference - short_unroll)
    print(f"FD vs Truncated: {truncated_gap}")

Advanced Research Directions

1. Stochastic Implicit Gradients

For huge systems, compute stochastic gradients:

def stochastic_implicit_gradient(A, x, grad_output, sample_rate=0.1):
    """
    Compute gradient stochastically for scalability.

    A random subset of indices is drawn, the adjoint system is solved on
    the corresponding principal sub-matrix, and the outer-product gradient
    is scattered back into a full-size matrix.

    Args:
        A: system matrix, indexed as a 2-D tensor.
        x: solution vector of the forward solve.
        grad_output: upstream gradient w.r.t. ``x``.
        sample_rate: fraction of indices to sample (default 0.1).

    Returns:
        Tensor shaped like ``A``: zero outside the sampled sub-block, and
        divided by ``sample_rate``.
        NOTE(review): the 1/sample_rate rescale is kept from the original
        and is a heuristic — confirm the intended estimator.

    ``solve`` is defined outside this snippet.
    """
    n = len(x)
    grad_A = torch.zeros_like(A)
    if n == 0:
        return grad_A

    # Always sample at least one index; int(n * sample_rate) can be 0.
    num_samples = min(n, max(1, int(n * sample_rate)))

    # Sample WITHOUT replacement: duplicate indices would put identical
    # rows in the sub-matrix and make the sampled system singular.
    rows = torch.randperm(n, device=A.device)[:num_samples]

    # Principal sub-system restricted to the sampled indices.
    A_sample = A[rows][:, rows]
    grad_sample = grad_output[rows]

    # Solve sampled adjoint system
    lambda_sample = solve(A_sample.T, grad_sample)

    # Scatter into the full gradient with one advanced-index assignment.
    # The chained form `grad_A[rows][:, rows] = ...` writes into a
    # temporary copy and leaves grad_A all zeros.
    grad_A[rows.unsqueeze(1), rows] = -torch.outer(lambda_sample, x[rows])

    return grad_A / sample_rate  # Rescale

2. Higher-Order Derivatives

For optimization requiring Hessians:

def hessian_vector_product(A, b, x, v):
    """
    Compute Hessian-vector product efficiently
    d²f/dA² · v without forming full Hessian.

    The incoming ``x`` is recomputed from ``A`` and ``b`` inside the
    function. ``solve`` and ``implicit_gradient`` are defined outside this
    snippet.
    NOTE(review): the second differentiation only works if
    ``implicit_gradient`` returns a tensor that is still attached to the
    autograd graph of ``A`` — confirm.
    """
    # First-order pass with gradient tracking forced on.
    with torch.enable_grad():
        x = solve(A, b)
        first_order = implicit_gradient(A, b, x)

    # Differentiate the first-order gradient w.r.t. A, contracted with v.
    second_order = torch.autograd.grad(
        first_order,
        A,
        grad_outputs=v,
        only_inputs=True,
        retain_graph=False,
    )

    return second_order[0]

3. Differentiable Preconditioning

Learn preconditioners end-to-end:

class DifferentiablePreconditioner(torch.nn.Module):
    """
    Learnable preconditioner ``M = diag(d) + L R``.

    ``L`` has shape (n, rank), ``R`` has shape (rank, n) and ``d`` has
    length n. The preconditioner is applied by forming ``M A`` and ``M b``
    densely and solving the transformed system.
    """

    def __init__(self, n, rank=10):
        """
        Args:
            n: dimension of the linear system.
            rank: rank of the low-rank term (default 10).
        """
        super().__init__()
        # Low-rank factorization
        self.L = torch.nn.Parameter(torch.randn(n, rank) / rank**0.5)
        self.R = torch.nn.Parameter(torch.randn(rank, n) / rank**0.5)
        # Diagonal correction
        self.d = torch.nn.Parameter(torch.ones(n))

    def _preconditioner(self):
        """Return the dense preconditioner matrix M = diag(d) + L R."""
        return torch.diag(self.d) + self.L @ self.R

    def forward(self, A, b):
        """
        Solve the preconditioned system ``(M A) x = M b``.

        Returns the solution produced by ``DifferentiableSolver`` (defined
        outside this class).
        """
        # Transform system
        M = self._preconditioner()
        MA = M @ A
        Mb = M @ b

        # Solve preconditioned system
        x = DifferentiableSolver.apply(MA, Mb)

        return x

    def condition_number_loss(self, A):
        """
        Loss to encourage good conditioning: ``log(kappa(M A))``.

        kappa is the 2-norm condition number, the ratio of the largest to
        the smallest singular value. Singular values are non-negative, so
        the log is defined for any non-singular ``M A``; a ratio of real
        parts of eigenvalues can be negative (log -> NaN) and is not the
        condition number for non-normal matrices.
        """
        MA = self._preconditioner() @ A

        # svdvals is differentiable and returns non-negative values.
        singular_values = torch.linalg.svdvals(MA)

        # Clamp the smallest singular value so a (near-)singular matrix
        # gives a large finite loss instead of inf.
        floor = torch.finfo(singular_values.dtype).eps
        kappa = singular_values.max() / singular_values.min().clamp_min(floor)

        return torch.log(kappa)

Conclusion

Differentiable solvers bridge numerical computation and deep learning, enabling end-to-end optimization of complex systems. Combined with sublinear algorithms, we can backpropagate through massive linear systems efficiently, unlocking new possibilities in scientific ML, physics-informed neural networks, and learned optimization.