본문으로 바로가기

[BOJ 10830] 행렬 제곱(C++)

category Algorithms/Math 2021. 10. 8. 14:49

행렬 제곱(Gold 4)

문제

전체 문제 보기

 

10830번: 행렬 제곱

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

www.acmicpc.net

 

접근법

이번 문제는 행렬의 거듭제곱을 최적화하는 문제입니다. 우선 행렬 간의 곱셈은 다음과 같이 3중 for문을 통해서 계산할 수 있습니다.

using Matrix = vector<vector<int>>;
Matrix operator*(const Matrix& lhs, const Matrix& rhs)
{
    int N = static_cast<int>(lhs.size());
    Matrix ret(N);

    for (int i = 0; i < N; ++i)
    {
        ret[i].resize(N);
        for (int j = 0; j < N; ++j)
        {
            for (int k = 0; k < N; ++k)
            {
                ret[i][j] += lhs[i][k] * rhs[k][j];
                ret[i][j] %= MOD; // MOD : 1,000
            }
        }
    }
    return ret;
}

일반적으로 임의의 행렬 \(M\)에 대해서 \(M^n\)은 \(M\)을 \(n\)번 곱해서 구할 수 있습니다. 하지만 이번 문제에서 제곱수 \(n\)은 최대 \(100,000,000,000\) 로 굉장히 큰 수입니다. 이는 곱셈의 제곱을 최적화하는 방법과 같은 방법으로 접근할 수 있습니다.

 

만약 \(M^{100}\)을 구해야한다면 \(M\)을 \(100\)번 곱하는 대신 \(M^{50}\) * \(M^{50}\) 으로 곱하기 한 번만에 계산할 수 있습니다. 마찬가지로 \(M^{50}\)는 또한 \(M^{25}\) * \(M^{25}\) 곱하기 한 번으로 계산할 수 있습니다.

 

만약 \(n\)이 홀수인 경우에는 \(n-1\)은 반드시 짝수이기 때문에 \(M^{n - 1 }\) * \(M\)으로 계산할 수 있습니다.
따라서, 정리하자면

  • \(n\)이 홀수인 경우 : \(M^n = M^{n-1} * M \)
  • \(n\)이 짝수인 경우 : \(M^n = M^{n/2} * M^{n/2} \)

따라서 다음과 같이 행렬의 거듭제곱을 계산할 수 있습니다.

Matrix Power(const Matrix& mat, long long pow)
{
    if (pow == 1) return mat;

    if (pow % 2 == 0) // 짝수일 경우
    {
        Matrix sqrtOfRet = Power(mat, pow / 2);
        return sqrtOfRet * sqrtOfRet;
    }
    else // 홀수일 경우
    {
        Matrix ret = Power(mat, pow - 1) * mat;
        return ret;
    }
}

성능 평가

행렬의 거듭제곱을 \(O(\mbox{log}B)\) 으로 최적화하였습니다. 

전체 구현 코드

#include <iostream>
#include <vector>
#include <algorithm>
#define endl '\n'
using namespace std;
const long long MOD = 1'000;
using Matrix = vector<vector<int>>;
Matrix operator*(const Matrix& lhs, const Matrix& rhs)
{
    int N = static_cast<int>(lhs.size());
    Matrix ret(N);

    for (int i = 0; i < N; ++i)
    {
        ret[i].resize(N);
        for (int j = 0; j < N; ++j)
        {
            for (int k = 0; k < N; ++k)
            {
                ret[i][j] += lhs[i][k] * rhs[k][j];
                ret[i][j] %= MOD;
            }
        }
    }
    return ret;
}

Matrix Power(const Matrix& mat, long long pow)
{
    if (pow == 1) return mat;

    if (pow % 2 == 0) // 짝수일 경우
    {
        Matrix sqrtOfRet = Power(mat, pow / 2);
        return sqrtOfRet * sqrtOfRet;
    }
    else // 홀수일 경우
    {
        Matrix ret = Power(mat, pow - 1) * mat;
        return ret;
    }
}
int main()
{
    //입출력 성능향상을 위한 설정
    ios_base::sync_with_stdio(false);
    cout.tie(NULL);
    cin.tie(NULL);

    int N;
    long long B;
    cin >> N >> B;

    Matrix mat(N);
    for (int i = 0; i < N; ++i)
    {
        mat[i].resize(N);
        for (int j = 0; j < N; ++j)
        {
            cin >> mat[i][j];
        }
    }

    Matrix answer = Power(mat, B);
    for (int i = 0; i < N; ++i)
    {
        for (int j = 0; j < N; ++j)
        {
            cout << answer[i][j] % MOD << " ";
        }
        cout << endl;
    }


    return 0;
}