파이썬 행렬 곱셈 순서(BOJ 11049)


문제

크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.

예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.

  • AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
  • BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.

행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.


입력

첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.

둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)

항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.


출력

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같다.


예제 입력 1

3
5 3
3 2
2 6

예제 출력 1

90


📝 풀어보기

문제의 예제에 값을 추가했다.

4
5 3
3 2
2 6
6 3

문제를 보면 (AB)C를 곱하는 경우, 5x3x2 + 5x2x6 = 90이 된다.

이것을 각 행렬로 풀어서 보면 (A행 A열 B열) + (A행 B열 C열)로 보인다.

아래 테이블을 보자

  A(5, 3) B(3, 2) C(2, 6) D(6, 3)
A(5, 3) 0 30(5x3x2) 30 + 60(5x2x6) = 90 96
B(3, 2) 0 0 36(3x2x6) 36 + 18(3x2x3) = 54
C(2, 6) 0 0 0 36(2x6x3)
D(6, 3) 0 0 0 0

AB의 경우 A행xA열xB열이다.

BC, CD도 마찬가지로 곱하는 앞 행렬의 행, 열과 뒷 행렬의 열을 곱해서 값이나온다.

(AB)C, (BC)D 의 경우에는 맨 처음행렬의 행, 그 다음행렬의 열, 마지막으로 곱하는 행렬의 열로 곱해서 이전 AB, BC의 결과값과 합한다.

그렇다면 (ABC)D의 경우는?

(a) + (b, c, d) + 비용

(a, b) + (c, d) + 비용

(a, b, c) + (d) + 비용

이 셋 중의 최소값이 A, B, C, D행렬 곱의 최소값이 된다.

이것을 나타내면

min(ABCD, min(A) + min(BCD) + 합치는 비용(A행 x A열 x D열),

min(AB) + min(CD) + 합치는 비용(A행 x B열 x D열),

min(ABC) + min(D) + 합치는 비용(A행 x C열 x D열),

)

이 된다.

이것을 식으로 나타내보면,

min(dp[j][j+i](최대값), dp[첫행렬 위치][k] + dp[k+1][마지막 행렬 위치] + (matrix[첫행렬 위치][0] * matrix[k][1] * matrix[마지막행렬 위치][1] )

import sys
input = sys.stdin.readline
# 4
# 5 3
# 3 2
# 2 6
# 6 3
N = int(input())
S = [list(map(int, input().split())) for i in range(N)]
# dp[i][j] -> i부터 j까지 최솟값
# (a) + (b, c, d) + 비용
# (a, b) + (c, d) + 비용
# (a, b, c) + (d) + 비용 이 중 최소값이 된다.
dp = [[0] * N for i in range(N)]
for i in range(1, N):
    for j in range(N - i):
        x = j + i
        # 문제에서 제시된 연산 횟수 최대값
        dp[j][x] = 2 ** 32

        for k in range(j, x):
            # n = 4일때
            # dp[0][1] -> dp[1][2] -> dp[2][3] -> dp[0][2] -> dp[1][3] -> dp[0][3]
            dp[j][x] = min(dp[j][x], dp[j][k] + dp[k + 1][x] + S[j][0] * S[k][1] * S[x][1])

print(dp[0][N - 1])

관심있을 포스팅