Solving Linear Systems with the CG Iteration (Conjugate Gradient Method)

Author: zbzhen, Modified: Sat Sep 9 14:25:34 2023

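1. Model problem

Let $A$ be a symmetric positive definite (SPD) matrix, and solve the linear system

$$Ax=b$$

Symmetry and full rank alone are not enough. The derivation below divides by $P_k^TAP_k$, which is guaranteed to be positive for every nonzero $P_k$ only when $A$ is positive definite.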
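A small hand-picked illustration of why positive definiteness matters is shown below. The matrix and right-hand side are assumptions chosen for the demo, not from the original. $A=\mathrm{diag}(1,-1)$ is symmetric and nonsingular, yet with $x_0=0$ the very first denominator $P_0^TAP_0$ is zero. The sketch uses plain Python lists, so it runs without NumPy.

```python
# Illustrative matrix (an assumption): symmetric, nonsingular, but indefinite.
A = [[1.0, 0.0],
     [0.0, -1.0]]
b = [1.0, 1.0]


def dot(u, v):
    """Inner product u^T v of two equal-length lists."""
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(M, v):
    """Dense matrix-vector product M v, with M stored as a list of rows."""
    return [dot(row, v) for row in M]


x0 = [0.0, 0.0]
r0 = [bi - ci for bi, ci in zip(b, matvec(A, x0))]  # r_0 = b - A x_0
P0 = r0[:]                                          # P_0 = r_0
denom = dot(P0, matvec(A, P0))                      # P_0^T A P_0

# 2.0 0.0: alpha_0 = r_0^T r_0 / (P_0^T A P_0) would divide by zero
print(dot(r0, r0), denom)
```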

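2. Analysis and derivation

Take $x_0=0$ and let the nonzero vectors $\{P_0,P_1,P_2,\dots\}$ be linearly independent. When $k+1$ equals the dimension of $x_{k+1}$, these vectors form a basis, so the solution can be written as

$$x_{k+1}=\sum_{j=0}^k \alpha_j P_j$$

where the $\alpha_j$ are scalar coefficients to be determined. Equivalently, in iterative form,

$$x_{k+1}=x_k+\alpha_k P_k$$

Denote the residual by

$$r_k = b - A x_k$$

Combining the two equations above gives

$$r_{k+1} = r_k - \alpha_k AP_k$$

which can be rearranged as

$$AP_k = \dfrac{1}{\alpha_k}(r_k-r_{k+1})$$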
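The residual recursion follows from linearity alone, so it holds for any step length and any direction. The quick check below uses arbitrary illustrative values for $A$, $b$, $x_k$, $P_k$ and $\alpha_k$ (assumptions, not from the original). It compares $b-Ax_{k+1}$ computed directly with $r_k-\alpha_kAP_k$.

```python
A = [[4.0, 1.0],
     [1.0, 3.0]]
b = [1.0, 2.0]


def dot(u, v):
    """Inner product u^T v."""
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(M, v):
    """Dense matrix-vector product M v (M is a list of rows)."""
    return [dot(row, v) for row in M]


def residual(x):
    """r = b - A x for the fixed A and b above."""
    return [bi - ci for bi, ci in zip(b, matvec(A, x))]


# Arbitrary illustrative iterate, direction and step length.
x_k = [0.5, -0.25]
P_k = [1.0, 2.0]
alpha_k = 0.3

x_next = [xi + alpha_k * pi for xi, pi in zip(x_k, P_k)]  # x_{k+1} = x_k + alpha_k P_k
r_direct = residual(x_next)                                # b - A x_{k+1}
AP_k = matvec(A, P_k)
r_recur = [ri - alpha_k * ai for ri, ai in zip(residual(x_k), AP_k)]  # r_k - alpha_k A P_k

print(r_direct, r_recur)  # equal up to floating-point rounding
```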

The remaining task is to choose $\alpha_k$ and $P_k$ so that the computational cost is as low as possible.

Set $P_0=r_0$ and write each $P_k$ as a linear combination of $r_0,r_1,r_2,\cdots,r_k$ such that the directions are mutually $A$-conjugate:

$$P_j^TAP_k = 0, \quad j\neq k.$$

This can also be written as a recursion,

$$P_{k+1}=r_{k+1}+ \beta_k P_k$$

where $\beta_k$ is chosen so that the residuals are mutually orthogonal:

$$r_k^T r_j = 0,\quad k\neq j.$$

Combine these equations as follows. Left-multiply the recursion $P_{k+1}=r_{k+1}+\beta_kP_k$ by $P_k^TA$. Then use $P_k^TAP_{k+1}=0$, the symmetry of $A$, the relation $AP_k=\frac{1}{\alpha_k}(r_k-r_{k+1})$ and $r_{k+1}^Tr_k=0$. This gives

$$
\begin{aligned}
\beta_k &= -\dfrac{P_k^TAr_{k+1}}{P_k^TAP_k} = -\dfrac{r_{k+1}^TAP_k}{P_k^TAP_k} \\
&= -\dfrac{r_{k+1}^T(r_k-r_{k+1})}{\alpha_k P_k^TAP_k} = \dfrac{r_{k+1}^Tr_{k+1}}{\alpha_k P_k^TAP_k}
\end{aligned}
$$

Finally, derive $\alpha_k$. Since $P_k=r_k+\beta_{k-1}P_{k-1}$ (with $\beta_{-1}=0$, because $P_0=r_0$), using $P_{k-1}^TAP_k=0$ and $r_k^Tr_{k+1}=0$ we get

$$
\begin{aligned}
P_k^T A P_k &= (r_k+\beta_{k-1}P_{k-1})^T A P_k = r_k^TAP_k \\
&= \dfrac{1}{\alpha_k}r_k^T(r_k-r_{k+1}) = \dfrac{1}{\alpha_k} r_k^T r_k
\end{aligned}
$$

which gives

$$\alpha_k = \dfrac{r_k^T r_k}{P_k^T A P_k}$$

Substituting this into the expression for $\beta_k$ then yields

$$\beta_k = \dfrac{r_{k+1}^Tr_{k+1}}{r_k^T r_k}$$
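As a sanity check, take one step from $x_0=0$ on a small SPD system. The conjugacy form $-P_0^TAr_1/(P_0^TAP_0)$ and the final ratio $r_1^Tr_1/(r_0^Tr_0)$ should give the same $\beta_0$. The $2\times 2$ system is an illustrative assumption, picked by hand so that every intermediate value is exact in binary floating point.

```python
A = [[4.0, 1.0],
     [1.0, 3.0]]
b = [1.0, 2.0]


def dot(u, v):
    """Inner product u^T v."""
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(M, v):
    """Dense matrix-vector product M v (M is a list of rows)."""
    return [dot(row, v) for row in M]


r0 = b[:]                                            # x_0 = 0, so r_0 = b
P0 = r0[:]                                           # P_0 = r_0
AP0 = matvec(A, P0)
alpha0 = dot(r0, r0) / dot(P0, AP0)                  # alpha_0 = r_0^T r_0 / P_0^T A P_0
r1 = [ri - alpha0 * ai for ri, ai in zip(r0, AP0)]   # r_1 = r_0 - alpha_0 A P_0

beta_conj = -dot(P0, matvec(A, r1)) / dot(P0, AP0)   # -P_0^T A r_1 / P_0^T A P_0
beta_ratio = dot(r1, r1) / dot(r0, r0)               # r_1^T r_1 / r_0^T r_0
P1 = [ri + beta_ratio * pi for ri, pi in zip(r1, P0)]  # P_1 = r_1 + beta_0 P_0

print(alpha0, beta_conj, beta_ratio)  # 0.25 0.0625 0.0625
print(dot(r1, r0), dot(P1, AP0))      # 0.0 0.0: r_1 orthogonal to r_0, P_1 A-conjugate to P_0
```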

Note that the derivation above does not seem to use full orthogonality. It only uses the relations between consecutive indices,
$$P_{k+1}^TAP_k = r_{k+1}^Tr_k = 0.$$
In fact, starting from $P_0=r_0$, the two recursions
$$P_{k+1}=r_{k+1}+\beta_kP_k, \qquad r_{k+1}=r_k-\alpha_kAP_k$$
imply, by induction on $k$, the full orthogonality relations
$$P_j^TAP_k = r_k^Tr_j = 0, \quad j\neq k.$$
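The full orthogonality can also be observed numerically. The sketch below runs CG on a $4\times 4$ SPD matrix using only the two recursions, where each step enforces only the neighbouring relations. It then checks every pair $j\neq k$. The matrix (the 1D discrete Laplacian $\mathrm{tridiag}(-1,2,-1)$), the right-hand side and the helper name `cg_history` are illustrative assumptions.

```python
def dot(u, v):
    """Inner product u^T v."""
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(M, v):
    """Dense matrix-vector product M v (M is a list of rows)."""
    return [dot(row, v) for row in M]


def cg_history(A, b, steps):
    """Hypothetical helper: build `steps` CG directions from x_0 = 0, keeping every r_k and P_k."""
    r = b[:]           # r_0 = b because x_0 = 0
    P = r[:]           # P_0 = r_0
    rs, Ps = [r], [P]
    for _ in range(steps - 1):
        AP = matvec(A, P)
        rr = dot(r, r)
        alpha = rr / dot(P, AP)
        r = [ri - alpha * ai for ri, ai in zip(r, AP)]  # r_{k+1} = r_k - alpha_k A P_k
        beta = dot(r, r) / rr
        P = [ri + beta * pi for ri, pi in zip(r, P)]    # P_{k+1} = r_{k+1} + beta_k P_k
        rs.append(r)
        Ps.append(P)
    return rs, Ps


n = 4
# 1D discrete Laplacian tridiag(-1, 2, -1), a standard SPD test matrix.
A = [[2.0 if i == j else -1.0 if abs(i - j) == 1 else 0.0 for j in range(n)]
     for i in range(n)]
b = [1.0, 2.0, 3.0, 4.0]

rs, Ps = cg_history(A, b, n)
pairs = [(i, j) for i in range(n) for j in range(n) if i != j]
worst_rr = max(abs(dot(rs[i], rs[j])) for i, j in pairs)
worst_PAP = max(abs(dot(Ps[i], matvec(A, Ps[j]))) for i, j in pairs)
print(worst_rr, worst_PAP)  # both at rounding-error level
```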

3. Algorithm

3.1. A basic version that is easy to follow

Given $A$, $b$, $x_0$, compute $x=A^{-1}b$.

Initialize: $r_0=b-Ax_0$, $P_0=r_0$, $k=0$.

While not converged:

  • $\alpha_k = \dfrac{r_k^T r_k}{P_k^T A P_k}$

  • $x_{k+1}=x_k+\alpha_k P_k$

  • $r_{k+1} = r_k - \alpha_k AP_k$

  • $\beta_k = \dfrac{r_{k+1}^Tr_{k+1}}{r_k^T r_k}$

  • $P_{k+1}=r_{k+1}+\beta_k P_k$

  • $k \leftarrow k+1$
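The steps of Section 3.1 can be transcribed almost line by line. In this plain-Python sketch, "converged" is interpreted as $\|r_k\|_2\le$ `tol`, with a `max_iter` safety cap. Both the names and the default values are assumptions. As written, the loop evaluates $AP_k$ twice and $r_k^Tr_k$ several times per iteration, which is exactly what the optimized version removes.

```python
def dot(u, v):
    """Inner product u^T v."""
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(M, v):
    """Dense matrix-vector product M v (M is a list of rows)."""
    return [dot(row, v) for row in M]


def cg_basic(A, b, x0, tol=1e-10, max_iter=1000):
    """Literal transcription of the basic algorithm; tol/max_iter are assumed stopping rules."""
    x = list(x0)
    r = [bi - ci for bi, ci in zip(b, matvec(A, x))]    # r_0 = b - A x_0
    P = r[:]                                            # P_0 = r_0
    k = 0
    while dot(r, r) ** 0.5 > tol and k < max_iter:      # "while not converged"
        alpha = dot(r, r) / dot(P, matvec(A, P))        # alpha_k (computes A P_k)
        x = [xi + alpha * pi for xi, pi in zip(x, P)]   # x_{k+1}
        r_new = [ri - alpha * ai for ri, ai in zip(r, matvec(A, P))]  # r_{k+1} (A P_k again)
        beta = dot(r_new, r_new) / dot(r, r)            # beta_k
        P = [ri + beta * pi for ri, pi in zip(r_new, P)]  # P_{k+1}
        r = r_new
        k += 1                                          # k <- k + 1
    return x, k


A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x, k = cg_basic(A, b, [0.0, 0.0])
print(x, k)  # close to the exact solution [1/11, 7/11], after k = 2 steps
```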

3.2. Optimized version

To avoid computing $AP_k$ and $r_k^Tr_k$ more than once, the algorithm can be rewritten as follows:

Given $A$, $b$, $x_0$, compute $x=A^{-1}b$.

Initialize: $r_0=b-Ax_0$, $P_0=r_0$, $k=0$, $s_0=r_0^Tr_0$.

While not converged:

  • $d_k = AP_k$
  • $\alpha_k = \dfrac{s_k}{P_k^T d_k}$
  • $x_{k+1}=x_k+\alpha_k P_k$
  • $r_{k+1} = r_k - \alpha_k d_k$
  • $s_{k+1} = r_{k+1}^Tr_{k+1}$
  • $P_{k+1}=r_{k+1}+ \dfrac{s_{k+1}}{s_k} P_k$
  • $k \leftarrow k+1$
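The saving can be measured by counting matrix-vector products. In this sketch the hypothetical wrapper class `CountingMatrix` increments a counter on every `A @ v`, and `cg_optimized` (also a hypothetical name) follows the optimized steps with an assumed `tol`/`max_iter` stopping rule. It uses one product for $r_0$ plus exactly one ($d_k$) per iteration. The basic version as written evaluates $AP_k$ twice per iteration.

```python
def dot(u, v):
    """Inner product u^T v."""
    return sum(ui * vi for ui, vi in zip(u, v))


class CountingMatrix:
    """Hypothetical wrapper: dense matrix (list of rows) that counts products A @ v."""

    def __init__(self, rows):
        self.rows = rows
        self.matvecs = 0

    def __matmul__(self, v):
        self.matvecs += 1
        return [dot(row, v) for row in self.rows]


def cg_optimized(A, b, x0, tol=1e-10, max_iter=1000):
    x = list(x0)
    r = [bi - ci for bi, ci in zip(b, A @ x)]   # r_0 = b - A x_0 (one product)
    P = r[:]                                    # P_0 = r_0
    s = dot(r, r)                               # s_0 = r_0^T r_0
    k = 0
    while s ** 0.5 > tol and k < max_iter:
        d = A @ P                               # d_k = A P_k, the only product per step
        alpha = s / dot(P, d)                   # alpha_k = s_k / P_k^T d_k
        x = [xi + alpha * pi for xi, pi in zip(x, P)]
        r = [ri - alpha * di for ri, di in zip(r, d)]  # r_{k+1} = r_k - alpha_k d_k
        s_new = dot(r, r)                       # s_{k+1}, reused as the next numerator
        beta = s_new / s
        P = [ri + beta * pi for ri, pi in zip(r, P)]
        s = s_new
        k += 1
    return x, k


A = CountingMatrix([[4.0, 1.0], [1.0, 3.0]])
x, k = cg_optimized(A, [1.0, 2.0], [0.0, 0.0])
print(x, k, A.matvecs)  # A.matvecs == k + 1
```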

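4. MATLAB / GNU Octave implementation

```matlab
function x = conjgrad(A, b, x)
    % Conjugate gradient for symmetric positive definite A; x is the initial guess.
    r = b - A * x;            % r_0 = b - A x_0
    p = r;                    % P_0 = r_0
    rsold = r' * r;           % s_0 = r_0' r_0

    for i = 1:length(b)       % at most n iterations (exact arithmetic terminates within n)
        Ap = A * p;           % d_k = A P_k, the only matrix-vector product per iteration
        alpha = rsold / (p' * Ap);
        x = x + alpha * p;
        r = r - alpha * Ap;
        rsnew = r' * r;       % s_{k+1}
        if sqrt(rsnew) < 1e-10    % absolute tolerance on ||r_{k+1}||_2
            break
        end
        p = r + (rsnew / rsold) * p;
        rsold = rsnew;
    end
end
```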

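5. Python implementation

5.1. Basic CG

```python
import numpy as np

def cg(A, b, x):
    """Conjugate gradient for SPD A; x is the initial guess (left unmodified)."""
    x = np.array(x, dtype=float)  # float copy: avoids mutating the caller's x and integer-dtype errors
    r = b - A @ x                 # r_0 = b - A x_0
    p = r.copy()                  # P_0 = r_0; copy, otherwise `r -= ...` below would also overwrite p
    rsold = r @ r                 # s_0 = r_0^T r_0
    for i in range(len(b)):
        Ap = A @ p                # d_k = A P_k, computed once per iteration
        alpha = rsold / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rsnew = r @ r             # s_{k+1}
        if np.sqrt(rsnew) < 1e-10:    # absolute tolerance on ||r_{k+1}||_2
            break
        p = r + (rsnew / rsold) * p
        rsold = rsnew
    print(np.sqrt(rsnew), i)      # final residual norm and last (0-based) iteration index
    return x
```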

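5.2. Preconditioned CG (PCG)

```python
import numpy as np

def pcg(A, b, x, Minv):
    """Preconditioned CG for SPD A; Minv is the (dense) inverse of the SPD preconditioner M."""
    x = np.array(x, dtype=float)  # float copy so the caller's x is not modified
    r = b - A @ x                 # r_0 = b - A x_0
    z = Minv @ r                  # z_0 = M^{-1} r_0
    p = z.copy()                  # P_0 = z_0
    rsold = z @ r                 # z_0^T r_0
    for i in range(len(b)):
        Ap = A @ p
        alpha = rsold / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        z = Minv @ r              # z_{k+1} = M^{-1} r_{k+1}
        rsnew = z @ r             # r^T M^{-1} r, so the test below is on the M^{-1}-weighted norm, not ||r||_2
        if np.sqrt(rsnew) < 1e-10:
            break
        p = z + (rsnew / rsold) * p
        rsold = rsnew
    print(np.sqrt(rsnew), i)      # final sqrt(z^T r) and last (0-based) iteration index
    return x
```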
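A common, cheap choice of preconditioner is the Jacobi (diagonal) preconditioner $M=\mathrm{diag}(A)$. In that case `Minv @ r` is just an elementwise division, so $M^{-1}$ never needs to be formed as a matrix. The sketch below follows the same steps as `pcg()` in plain Python, but it stops on the true residual norm $\|r\|_2$ rather than on $\sqrt{z^Tr}$. The function name `jacobi_pcg`, the tolerance and the badly scaled test system are assumptions.

```python
def dot(u, v):
    """Inner product u^T v."""
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(M, v):
    """Dense matrix-vector product M v (M is a list of rows)."""
    return [dot(row, v) for row in M]


def jacobi_pcg(A, b, tol=1e-10, max_iter=1000):
    """Hypothetical PCG with M = diag(A); applying M^{-1} is an elementwise division."""
    n = len(b)
    diag = [A[i][i] for i in range(n)]          # positive for an SPD matrix
    x = [0.0] * n
    r = b[:]                                    # r_0 = b because x_0 = 0
    z = [ri / di for ri, di in zip(r, diag)]    # z_0 = M^{-1} r_0
    p = z[:]                                    # P_0 = z_0
    rz = dot(z, r)                              # z_0^T r_0
    k = 0
    while dot(r, r) ** 0.5 > tol and k < max_iter:  # stop on the true residual norm
        Ap = matvec(A, p)
        alpha = rz / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * ai for ri, ai in zip(r, Ap)]
        z = [ri / di for ri, di in zip(r, diag)]    # z_{k+1} = M^{-1} r_{k+1}
        rz_new = dot(z, r)
        beta = rz_new / rz
        p = [zi + beta * pi for zi, pi in zip(z, p)]
        rz = rz_new
        k += 1
    return x, k


# Illustrative badly scaled SPD system (an assumption); exact solution is [0, 1].
A = [[100.0, 1.0],
     [1.0, 1.0]]
b = [1.0, 1.0]
x, k = jacobi_pcg(A, b)
print(x, k)
```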