크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 이때 행렬 N개를 곱하는데 필요한 곱셈의 연산 수는 행렬을 곱하는 순서에 따라 따라지게 된다. 즉, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우 A, B를 먼저 곱하거나 B, C를 먼저 곱하는 경우 각각 90번, 126번으로 곱 횟수가 다르다. 이때 행렬 N개의 크기가 주어지고, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하려면 다음과 같은 알고리즘을 통해 구할 수 있다.
우리가 원하는 것은 '최소 연산 횟수'이다. 이때 '최소 연산 횟수'를 구하기 위해서 A×B×C×D×E 연산 시 (각 부분 문제의 연산 횟수)+(마지막 연산 횟수)의 최솟값을 구하면 된다. 즉, A×B, B×C, ..., A×B×C, B×C×D, ..., A×B×C×D, B×C×D×E까지 모든 부분 연산의 최소 연산 횟수를 구하고 마지막으로 A×B×C×D×E의 최소 연산 횟수를 구하면 된다. 이를 위해서 우리는 Dynamic programming 알고리즘을 활용할 수 있다. A = (5, 3), B = (3, 2), C = (2, 6), D = (6, 4)일 때 최소 연산 횟수를 구하는 DP Table은 다음과 같다. 이때 각 테이블의 원소는 Table(a, b)인 경우 a~b까지의 행렬들을 곱하는 경우의 최소 연산 횟수를 의미한다.
0 | 30 (A×B) | 90 (A×B×C, 90 vs 126) | 118 (A×B×C×D, 210 vs 118 vs 132) |
0 | 0 | 36 (B×C) | 72 (B×C×D, 108 vs 72) |
0 | 0 | 0 | 48 (C×D) |
0 | 0 | 0 | 0 |
python 코드로는 다음과 같이 구현할 수 있다.
n = int(input()) # n: 행렬의 개수
matrix_list = list() # matrix_list: 행렬의 크기를 담을 리스트
for _ in range(n):
matrix = tuple(map(int, input().split()))
matrix_list.append(matrix)
dp = [[0 for i in range(n)] for i in range(n)] # N × N dp table 선언
for len_multi in range(1, n):
# len_multi: 곱셈 길이 [Ex) A×B의 경우 len_multiply=1이다.]
for start_matrix in range(0, n-len_multi):
# start_matrix: 곱셈을 시작할 행렬의 번호 [Ex) 곱셈을 시작할 행렬이 A인 경우 start_matrix=0이다.]
end_matrix = start_matrix + len_multi # end_matrix: 곱셈에서 마지막 행렬의 인덱스
dp[start_matrix][end_matrix] = float("inf") # 이후 최소 곱셈 횟수를 저장하기 위해 0으로 저장되어 있는 값을 무한대로 변경한다.
for center_matrix in range(start_matrix, end_matrix):
# 행렬 곱셈 횟수를 구하는 연산이 a × b × c라고 할 때 (a×b, b×c 두 행렬을 곱하는 경우)
multi_row = matrix_list[start_matrix][0] # multi_row: a
multi_cen = matrix_list[center_matrix][1] # multi_cen: b
multi_col = matrix_list[end_matrix][1] # multi_col: c
temp = dp[start_matrix][center_matrix]+dp[center_matrix+1][end_matrix]+multi_row*multi_cen*multi_col # temp: 해당 경우의 결과값
dp[start_matrix][end_matrix] = min(dp[start_matrix][end_matrix], temp) # temp와 현재 저장된 값을 비교하여 더 작은 값을 저장한다.
print(dp[0][-1]) # 결과값을 출력한다.
먼저 1~n-1까지의 곱셈 길이를 가지는 경우를 모두 고려해야 한다. 이에 해당 시간 복잡도는 O(n)이다. 그리고 현재 탐색하는 곱셈 길이를 L이라 할 때 곱셈의 시작 행렬이 될 수 있는 행렬의 개수는 0~n-L-1개의 경우(O(n))를 탐색해야 한다. 그리고 각 경우마다 각 부분 문제를 고려하여 최소 곱셈 횟수를 탐색하는 데 걸리는 시간 복잡도가 O(n)이므로 총 시간 복잡도는 O(n³)이 된다.
python 코드로 제출하면 시간 초과가 발생하니 pypy3로 제출해야 한다.
'Algorithm > Dynamic Programming' 카테고리의 다른 글
[백준] 26156: 나락도 락이다 (python 파이썬) (0) | 2023.03.13 |
---|---|
[백준] 12865: 평범한 배낭 (python 파이썬) (0) | 2022.11.21 |
[백준] 11404: 플로이드 (python 파이썬) (0) | 2022.11.14 |