Direct Methods for Sparse Matrices
Usage
The direct methods for sparse matrix solutions feature a matrix factorization for solving a set of equations. This procedure is call factorization or decomposition. We can also support factorization or decomposition via shared memory across kernels. The design is to store the factorized matrix in the C++ memory and pass an identifying code (an integer) to the solution operators. Here is how you solve $Ax_i = b_i$ for a list of $(x_i, b_i)$
Step 1: Factorization
A_factorized = factorize(A)
Step2: Solve
x1 = A_factorized\b1
x2 = A_factorized\b2
......
Compared to A\b
, the factorize-then-solve approach is more efficient, especially when you have to solve a lot of equations.
Control Flow Safety
The factorize-then-solve method is also control flow safe. That is, we can safely use it in the control flow and gradient backpropagation is correct. For example, if the matrix $A$ keeps unchanged throught the loop, we might want to factorize $A$ first and then use the factorized $A$ to solve equations repeatly. To verify the control flow safety, consider the following code, where in the loop we have
\[u_{i+1} = A^{-1}(u_i + r), i=1,2,\ldots\]
using ADCME
using SparseArrays
using ADCMEKit
function while_loop_simulation(vv, rhs, ns = 10)
A = SparseTensor(ii, jj, vv, 10, 10) + spdiag(10)*100.
Afac = factorize(A)
ta = TensorArray(ns)
i = constant(2, dtype=Int32)
ta = write(ta, 1, ones(10))
function condition(i, ta)
i<= ns
end
function body(i, ta)
u = read(ta, i-1)
res = Afac\(u + rhs)
# res = u
ta = write(ta, i, res)
i+1, ta
end
_, out = while_loop(condition, body, [i, ta])
sum(stack(out)^2)
end
sess = Session(); init(sess)
A = sprand(10, 10, 0.8)
ii, jj, vv = find(constant(A))
k = length(vv)
# Test 1: autodiff through A
pl = placeholder(rand(k))
res = while_loop_simulation(pl, rhs , 100)
gradview(sess, pl, res, rand(k))
# Test 2: autodiff through rhs
pl = placeholder(rand(10))
res = while_loop_simulation(vv, pl , 100)
gradview(sess, pl, res, rand(10))
We have the following convergence plot