본문으로 바로가기

분할 정복을 활용한 Merge Sort 구현

접근법

Merge Sort는 분할 정복을 활용한 대표적인 알고리즘 중 하나입니다. 분할 정복을 연습하기 위해서 Merge Sort를 구현해 보았습니다. Merge Sort의 원리는 다음과 같습니다. 하나의 배열을 원소가 하나가 남을 때까지 양분(분할)합니다. 재귀적으로 분할된 원소를 정렬하며 다시 병합하여 온전히 정렬된 배열을 만듭니다.
Merge Sort의 가장 큰 특징은 안정 정렬이라는 것입니다. 안정 정렬이란 똑같은 원소라면 그 상대 위치가 변하지 않는다는 뜻입니다. 예를 들어 아래와 같은 데이터 타입이 있다고 가정해 봅시다.

struct data{
    int key;
    char name;
}

위 데이터 타입을 가진 데이터들이 아래와 같이 정렬되지 않은 상태로 있습니다.

[3, a] [5, b] [6, c] [5, d] [4, e] [8, f] [5, g] [2, h]

정렬되지 않은 배열에 key값이 5인 원소가 3개 있습니다. 이를 key값을 기준으로 정렬하면 아래와 같습니다.

[2, h] [3, a] [4, e] [5, b] [5, d] [5, g] [6, c] [8, f]

여기서 주목할 점은 중복된 key값 5를 가진 원소 3개의 상대위치는 변하지 않습니다. 이를 안정 정렬이라고 부르며 병합 정렬은 안정 정렬을 보장하는 정렬입니다. 반면 Quick sort는 안정 정렬을 보장하지 않습니다.

 

두 번째 특징은 Quick sort의 공간 복잡도가 \(O(1)\)인 반면 병합 정렬의 공간 복잡도는 \(O(n)\) 이라는 점입니다. 이는 병합 과정에서 모든 원소를 정렬하여 저장할 임시 배열이 필요하기 때문입니다. 그리고 마지막 특징은 시간 복잡도는 Quick sort와 마찬가지로 \(O(n\mbox{log}n)\) 이라는 점입니다. 그래서 \(O(n^2)\) 시간 복잡도를 가지는 Insert sort, Selection sort, Buble sort 보다는 매우 빠른 속도를 보입니다. 하지만 일반적인 경우에서는 Merge Sort 보다 Quick sort가 더 빠르다고 알려져 있습니다.

 

C++에서 병합정렬을 사용하기 위해서는 algorithm 헤더에 있는 std::stable_sort를 사용할 수 있습니다.

병합 정렬 구현

아래와 같이 코드로 구현할 수 있습니다. BOJ11004번 문제를 통해서 정확성 및 성능 테스트를 진행하였습니다.

#include <iostream>
#include <vector>

#define endl '\n'
using namespace std;

void Merge(vector<int>& arr, const int start, const int end)
{
    const int mid = (start + end) / 2;
    int right = (start + end) / 2 + 1;
    int left = start;

    vector<int> temp(end - start + 1);
    for (int i = 0; i < temp.size(); ++i)
    {
        if (left <= mid && right <= end)
        {
            if (arr[left] < arr[right])
                temp[i] = arr[left++];
            else temp[i] = arr[right++];
        }
        else
        {
            if (left <= mid) temp[i] = arr[left++];
            else temp[i] = arr[right++];
        }
    }

    for (int i = 0; i < temp.size(); ++i)
    {
        arr[i + start] = temp[i];
    }
}


void MergeSort(vector<int>& arr, const int start, const int end)
{
    if (start == end) return;
    const int mid = (start + end) / 2;


    MergeSort(arr, start, mid);
    MergeSort(arr, mid + 1, end);

    Merge(arr, start, end);
}

int main() {
    // 입출력 성능 향상을 위한 설정
    ios_base::sync_with_stdio(false);
    cout.tie(NULL);
    cin.tie(NULL);

    int N, K; // N(1 ≤ N ≤ 5,000,000), K (1 ≤ K ≤ N)
    cin >> N >> K;

    vector<int> arr(N);
    for (int i = 0; i < N; ++i)
    {
        cin >> arr[i];
    }

    MergeSort(arr, 0, N - 1);

    cout << arr[K - 1];

    return 0;
}