본문 바로가기

알고리즘(Algorithm)/Algorithm

[알고리즘] 연쇄 행렬 곱셈 알고리즘(Chained Matrix Multiplication)

반응형

 

연쇄 행렬 곱셈 알고리즘(Chained Matrix Multiplication)

연쇄 행렬(ex, 2 by 4 행렬 x 4 by 7 행렬 x 7 by 9 행렬 x ...)의 곱셈을 할 때 곱셈 연산을 최소로 하는 곱셈 순서를 찾는 알고리즘이다. 행렬 곱셈은 결합 법칙이 성립하기 때문에 곱하는 순서에 상관은 없지만, 순서에 따라 연산의 양이 달라진다.

 

2 x 3 행렬 A와 3 x 5 행렬 B, 5 x 7 행렬 C가 있다고 하자. 이 행렬들의 곱 ABC를 구하는 경우는 아래와 같다.

1. AB를 먼저 곱하고 C를 곱할 때의 연산 수 = 2 x 3 x 5 +  2 x 5 x 7 = 100

2. BC를 먼저 곱하고 A를 곱할 때의 연산 수 = 3 x 5 x 7 + 2 x 3 x 7 = 147

 

둘은 같은 곱셈이지만, 연산의 수가 다름을 알 수 있다. 아래는 연쇄행렬 최소곱셈 알고리즘의 점화식이다.

$$M[i][j] = \begin{cases} \min{(M[i][k] + M[k+1][j] + d_{i-1}d_{k}d_{j})}, & (i \le k \le j-1) \\ 0, & (i = j) \end{cases}$$

 

위 ABC 행렬을 예시로 해당 점화식을 전개해보자.

\(d_{0} = 2, d_{1} = 3, d_{2} = 5, d_{3} = 7\)

1. \(M[1][2]\) (행렬 A, B의 곱 연산 수) \( (1 \le k \le 1) \)

\( = \min{(M[1][k] + M[k+1][2] + d_{0}d_{k}d_{2})} \)

\( = M[1][1] + M[2][2] + d_{0}d_{1}d_{2} \)

\( = 0 + 0 + 2 \times 3 \times 5 \)

\( = 30 \)

 

2. \(M[2][3]\) (행렬 B, C의 곱 연산 수) \( (2 \le k \le 2) \)

\( = \min{(M[2][k] + M[k+1][3] + d_{1}d_{k}d_{3})} \)

\( = M[2][2] + M[3][3] + d_{1}d_{2}d_{3} \)

\( = 0 + 0 + 3 \times 5 \times 7 \)

\( = 105 \)

 

3. \(M[1][3]\) (행렬 A, B, C의 곱 연산 수) \( (1 \le k \le 2) \)

\( = \min{(M[1][k] + M[k+1][3] + d_{0}d_{k}d_{3})} \)

\( = \min{(M[1][1] + M[2][3] + d_{0}d_{1}d_{3}, M[1][2] + M[3][3] + d_{0}d_{2}d_{3})} \)

\( = \min{(0 + 105 + 2 \times 3 \times 7, 30 + 0 + 2 \times 5 \times 7)} \)

\( = \min{(147, 100)} \)

\( = 100 \)

 

최소 곱셈이 100회가 됨을 알 수 있다.

 

위의 예시를 코드로 나타내면 아래와 같다.

#include <iostream>

using namespace std;

int dp[4][4];

int main() {
    int d[4] = { 2, 3, 5, 7 };
    
    // 행렬의 수 만큼 반복
    for (int n = 0; n < 3; n++) {
        for (int i = 1; i <= 3 - n; i++) {
            int j = i + n;
            if (i == j) {
                dp[i][j] = 0;
            }
            else {
                dp[i][j] = 9999999;
                for (int k = i; k <= j - 1; k++) {
                    dp[i][j] = min(dp[i][j], dp[i][k] + dp[k + 1][j] + d[i - 1] * d[k] * d[j]);
                }
            }
        }
    }
    
    return 0;
}

 

[Reference]

https://en.wikipedia.org/wiki/Matrix_chain_multiplication

 

 

728x90