공부/Data Structure

AVL Tree

sudo 2021. 8. 6. 19:28

지난번에 작성한 Binary Search Tree에 이어서 자가 균형 트리의 한 종류인 AVL(Adelson-Velsky and Landis) Tree를 작성해보았다. 자가 균형 트리란 편향 트리(skewed tree)가 되지 않도록, 트리의 균형을 스스로 맞추는 트리이다. 당연히 편향 트리가 되면 트리로 탐색을 하는데 비효율적이고 탐색, 삽입의 평균적인 시간 복잡도 logN을 보장하지 못한다.

 

구현한 부분에 있어서 특징은, 코드에서 볼 수 있듯이, Rebalance함수가 재귀적으로 계속 부모 노드를 타고 올라가면서 호출된다는 점이다. insert함수에서는 새로 삽입된 NewNode를 타고 올라가고, erase에서는 지우는 동작이 끝나고 지운 노드를 대체할 노드(LeftMax or RightMin)의 부모 노드(Parent)를 타고 올라가면서 Rebalance를 재귀적을 호출한다.

 

또한 LL, LR, RL, RR case에서 기준 노드가 바뀌어서, RotateLeft, RotateRight 함수의 리턴으로 바뀐 기준 노드로 할당해주는 것 또한 특징이다.

#pragma once

template <typename KEY, typename VALUE>
class CAVLTreeNode
{
	template <typename KEY, typename VALUE>
	friend class CAVLTree;

	template <typename KEY, typename VALUE>
	friend class CAVLTreeIterator;

private:
	CAVLTreeNode()	:
		m_Parent(nullptr),
		m_Left(nullptr),
		m_Right(nullptr),
		m_Next(nullptr),
		m_Prev(nullptr)
	{

	}
	~CAVLTreeNode()
	{

	}

private:
	CAVLTreeNode<KEY, VALUE>* m_Parent;
	CAVLTreeNode<KEY, VALUE>* m_Left;
	CAVLTreeNode<KEY, VALUE>* m_Right;
	CAVLTreeNode<KEY, VALUE>* m_Next;
	CAVLTreeNode<KEY, VALUE>* m_Prev;


	// main에서 iterator->first, iterator->second로 접근하려고
	// public으로 선언
public:
	KEY first;
	VALUE second;

public:	
	bool IsParent()	const
	{
		return m_Parent != nullptr;
	}

	KEY GetParentKey()	const
	{
		return m_Parent->first;
	}

	VALUE GetParentValue()	const
	{
		return m_Parent->second;
	}

	bool IsLeft()	const
	{
		return m_Left != nullptr;
	}

	KEY GetLeftKey()	const
	{
		return m_Left->first;
	}

	VALUE GetLeftValue()	const
	{
		return m_Left->second;
	}

	bool IsRight()	const
	{
		return m_Right != nullptr;
	}

	KEY GetRightKey()	const
	{
		return m_Right->first;
	}

	VALUE GetRightValue()	const
	{
		return m_Right->second;
	}

};

template <typename KEY, typename VALUE>
class CAVLTreeIterator
{
	template <typename KEY, typename VALUE>
	friend class CAVLTree;

private:
	CAVLTreeNode<KEY, VALUE>* m_Node;

// CAVLTreeNode와 달리 main에서도 사용을 해야하는데
// private이 되면 main에서 접근을 못하므로
// 꼭 public으로 해주기
public:
	CAVLTreeIterator()
		: m_Node(nullptr)
	{

	}

	~CAVLTreeIterator()
	{

	}

public:
	void operator ++ ()
	{
		m_Node = m_Node->m_Next;
	}

	void operator ++ (int)
	{
		m_Node = m_Node->m_Next;
	}

	void operator -- ()
	{
		m_Node = m_Node->m_Prev;
	}

	void operator -- (int)
	{
		m_Node = m_Node->m_Prev;
	}

	bool operator == (const CAVLTreeIterator<KEY, VALUE>& iter) const
	{
		// return *this == iter 이렇게 쓰면 또 이 != 연산자 오버로딩 호출하므로 쓰면 안된다.
		return m_Node == iter.m_Node;
	}

	bool operator != (const CAVLTreeIterator<KEY, VALUE>& iter) const
	{
		// return *this != iter 이렇게 쓰면 또 이 != 연산자 오버로딩 호출하므로 쓰면 안된다.
		return m_Node != iter.m_Node;
	}

	bool operator == (const CAVLTreeNode<KEY, VALUE>* Node)	const
	{
		return m_Node == Node;
	}

	bool operator != (const CAVLTreeNode<KEY, VALUE>* Node)	const
	{
		return m_Node != Node;
	}

	CAVLTreeNode<KEY, VALUE>* operator -> ()	const
	{
		return m_Node;
	}

};

template <typename KEY, typename VALUE>
class CAVLTree
{
public:
	CAVLTree()
	{
		m_Root = nullptr;
		m_Size = 0;

		m_Begin = new NODE;
		m_End = new NODE;

		m_Begin->m_Next = m_End;
		m_End->m_Prev = m_Begin;
	}

	~CAVLTree()
	{
		PNODE	DeleteNode = m_Begin;

		while (DeleteNode)
		{
			PNODE	Next = DeleteNode->m_Next;

			delete	DeleteNode;

			DeleteNode = Next;
		}
	}

public:
	typedef CAVLTreeIterator<KEY, VALUE> iterator;

private:
	typedef CAVLTreeNode<KEY, VALUE>* PNODE;
	typedef CAVLTreeNode<KEY, VALUE> NODE;

private:
	PNODE m_Root;
	PNODE m_Begin;
	PNODE m_End;
	int m_Size;

public:
	int size() const
	{
		return m_Size;
	}

	bool empty()	const
	{
		return m_Size == 0;
	}

	iterator end() const
	{
		iterator iter;
		iter.m_Node = m_End;
		return iter;
	}

	iterator begin() const
	{
		iterator iter;
		iter.m_Node = m_Begin->m_Next;
		return iter;
	}

	void clear()
	{
		PNODE	Node = m_Begin->m_Next;

		while (Node != m_End)
		{
			PNODE	Next = Node->m_Next;

			delete	Node;

			Node = Next;
		}

		m_Begin->m_Next = m_End;
		m_End->m_Prev = m_Begin;

		m_Size = 0;

		m_Root = nullptr;
	}


	void PreOrder(void (*pFunc)(const KEY&, const VALUE&))
	{
		PreOrder(pFunc, m_Root);
	}

	void InOrder(void (*pFunc)(const KEY&, const VALUE&))
	{
		InOrder(pFunc, m_Root);
	}

	void PostOrder(void (*pFunc)(const KEY&, const VALUE&))
	{
		PostOrder(pFunc, m_Root);
	}

	// 없는 노드를 찾으려하면 end() 리턴
	iterator Find(const KEY& key) const
	{
		return Find(key, m_Root);
	}

	void insert(const KEY& key, const KEY& value)
	{
		// 처음 노드를 삽입하는 경우
		if (!m_Root)
		{
			m_Root = new NODE;
			m_Root->first = key;
			m_Root->second = value;

			m_Begin->m_Next = m_Root;
			m_Root->m_Prev = m_Begin;

			m_End->m_Prev = m_Root;
			m_Root->m_Next = m_End;
		}
		else
		{
			insert(key, value, m_Root);
		}
		++m_Size;
	}

	iterator erase(const KEY& key)
	{
		iterator iter = Find(key);


		if (iter == end() || iter == m_Begin)
			return iter;

		return erase(iter);
	}

	// 지운 노드의 다음 노드의 iterator를 리턴 
	iterator erase(const iterator& iter)
	{
		// 지우려는 노드가 리프노드인 경우
		if (iter.m_Node->m_Left == nullptr && iter.m_Node->m_Right == nullptr)
		{
			// 트리에 루트노드 하나밖에 없는 경우
			if (iter.m_Node == m_Root)
			{
				delete m_Root;
				m_Root = nullptr;

				m_Begin->m_Next = m_End;
				m_End->m_Prev = m_Begin;

				--m_Size;

				return end();
			}

			PNODE Parent = iter->m_Parent;
			if (Parent->m_Left == iter.m_Node)
			{
				Parent->m_Left = nullptr;
			}

			else
			{
				Parent->m_Right = nullptr;
			}

			PNODE Prev = iter->m_Prev;
			PNODE Next = iter->m_Next;

			Prev->m_Next = Next;
			Next->m_Prev = Prev;

			//iter.m_Node->m_Parent 상대로 Rebalnce?

			delete iter.m_Node;

			--m_Size;

			Rebalance(Parent);

			iterator	result;
			result.m_Node = Next;

			return result;
		}

		// 지우려는 노드가 리프노드가 아닌 경우
		else
		{
			// 지운 노드 기준 왼쪽에서 가장 큰 노드로
			// 지운 노드 자리를 대체하자
			if (iter->m_Left)
			{
				PNODE LeftMax = FindMax(iter->m_Left);

				// iter인자는 const iterator& 인데 이렇게 바꿔줘도 되나?
				iter.m_Node->first = LeftMax->first;
				iter.m_Node->second = LeftMax->second;

				PNODE Parent = LeftMax->m_Parent;
				// 지우려는 노드의 왼쪽에서 가장 큰 노드의 왼쪽 자식
				// 오른쪽 자식은 있을리 없다
				// 없으면 nullptr일 것이다
				PNODE LeftMaxLeftChild = LeftMax->m_Left;

				if (Parent->m_Left == LeftMax)
				{
					Parent->m_Left = LeftMaxLeftChild;
				}
				else
				{
					Parent->m_Right = LeftMaxLeftChild;
				}

				if (LeftMaxLeftChild)
				{
					LeftMaxLeftChild->m_Parent = Parent;
				}

				PNODE Prev = LeftMax->m_Prev;
				PNODE Next = LeftMax->m_Next;

				Prev->m_Next = Next;
				Next->m_Prev = Prev;

				delete LeftMax;

				--m_Size;

				Rebalance(Parent);

				iterator result;
				result.m_Node = Next;
				return result;
			}

			// 오른쪽에서 가장 작은 노드를 찾아서
			// 지운 노드 자리를 대체하자
			else
			{
				PNODE RightMin = FindMin(iter.m_Node->m_Right);

				iter.m_Node->first = RightMin->first;
				iter.m_Node->second = RightMin->second;

				PNODE Parent = RightMin->m_Parent;

				// 지우려는 노드의 오른쪽에서 가장 작은 노드의 오른쪽 자식
				// 왼쪽 자식은 있을리 없다
				// 없으면 nullptr일 것이다
				PNODE RightMinRightChild = RightMin->m_Right;

				if (Parent->m_Left == RightMin)
				{
					Parent->m_Left = RightMinRightChild;
				}
				else
				{
					Parent->m_Right = RightMinRightChild;
				}

				if (RightMinRightChild)
				{
					RightMinRightChild->m_Parent = Parent;
				}

				PNODE Prev = RightMin->m_Prev;
				PNODE Next = RightMin->m_Next;

				Prev->m_Next = Next;
				Next->m_Prev = Prev;

				delete RightMin;

				--m_Size;

				Rebalance(Parent);

				iterator result;
				result.m_Node = Next;
				return result;
			}
		}
	}

	void PreOrder(void (*pFunc)(const KEY&, const VALUE&), PNODE Node)
	{
		if (!Node)
			return;

		pFunc(Node->first, Node->second);
		PreOrder(pFunc, Node->m_Left);
		PreOrder(pFunc, Node->m_Right);
	}

	void InOrder(void(*pFunc)(const KEY&, const VALUE&), PNODE Node)
	{
		if (!Node)
			return;

		InOrder(pFunc, Node->m_Left);
		pFunc(Node->first, Node->second);
		InOrder(pFunc, Node->m_Right);
	}

	void PostOrder(void(*pFunc)(const KEY&, const VALUE&), PNODE Node)
	{
		if (!Node)
			return;

		PostOrder(pFunc, Node->m_Left);
		PostOrder(pFunc, Node->m_Right);
		pFunc(Node->first, Node->second);
	}

private:
	PNODE FindMax(PNODE Node) const
	{
		if (Node->m_Right)
		{
			return FindMax(Node->m_Right);
		}
		return Node;
	}

	PNODE FindMin(PNODE Node) const
	{
		if (Node->m_Left)
		{
			return FindMin(Node->m_Left);
		}
		return Node;
	}


	// 새로 삽입한 노드를 리턴
	PNODE insert(const KEY& key, const KEY& value, PNODE node)
	{
		if (node->first > key)
		{
			if (node->m_Left)
			{
				return insert(key, value, node->m_Left);
			}

			PNODE NewNode = new NODE;
			NewNode->first = key;
			NewNode->second = value;

			NewNode->m_Parent = node;
			node->m_Left = NewNode;

			PNODE Prev = node->m_Prev;

			Prev->m_Next = NewNode;
			NewNode->m_Prev = Prev;

			node->m_Prev = NewNode;
			NewNode->m_Next = node;

			Rebalance(NewNode);

			return NewNode;
		}

		else
		{
			if (node->m_Right)
			{
				return insert(key, value, node->m_Right);
			}

			PNODE NewNode = new NODE;

			NewNode->first = key;
			NewNode->second = value;

			NewNode->m_Parent = node;
			node->m_Right = NewNode;

			PNODE Next = node->m_Next;

			Next->m_Prev = NewNode;
			NewNode->m_Next = Next;

			NewNode->m_Prev = node;
			node->m_Next = NewNode;

			Rebalance(NewNode);

			return NewNode;
		}
	}

	iterator Find(const KEY& key, PNODE node) const
	{
		if (!node)
		{
			return end();
		}

		else if (node == m_End || node == m_Begin)
		{
			return end();
		}

		else if (node->first == key)
		{
			iterator iter;
			iter.m_Node = node;
			return iter;
		}
		
		if (node->first > key)
		{
			return Find(key, node->m_Left);
		}

		return Find(key, node->m_Right);

	}

	int Height(PNODE Node)
	{
		if (!Node)
		{
			return 0;
		}

		int LeftHeight = Height(Node->m_Left);
		int RightHeight = Height(Node->m_Right);

		int h = LeftHeight > RightHeight ? LeftHeight : RightHeight;

		return h + 1;
	}

	PNODE RotateLeft(PNODE Node)
	{
		if (!Node)
			return nullptr;

		PNODE Parent = Node->m_Parent;

		PNODE RightChild = Node->m_Right;
		PNODE RightLeftChild = Node->m_Right->m_Left;

		if (RightLeftChild)
		{
			RightLeftChild->m_Parent = Node;
		}

		if (Parent)
		{
			if (Parent->m_Left == Node)
			{
				Parent->m_Left = RightChild;
			}

			else
			{
				Parent->m_Right = RightChild;
			}
		}
		// Parent가 없다는 의미는 Node가 Root였다는 의미
		else
		{
			m_Root = RightChild;
		}

		RightChild->m_Parent = Parent;
		
		RightChild->m_Left = Node;
		Node->m_Parent = RightChild;

		Node->m_Right = RightLeftChild;

		return RightChild;
	}

	PNODE RotateRight(PNODE Node)
	{
		if (!Node)
			return nullptr;

		PNODE Parent = Node->m_Parent;

		PNODE LeftChild = Node->m_Left;
		PNODE LeftRightChild = Node->m_Left->m_Right;

		if (LeftRightChild)
		{
			LeftRightChild->m_Parent = Node;
		}

		if (Parent)
		{
			if (Parent->m_Left == Node)
			{
				Parent->m_Left = LeftChild;
			}

			else
			{
				Parent->m_Right = LeftChild;
			}
		}
		// Parent가 없다는 의미는 Node가 Root였다는 의미
		else
		{
			m_Root = LeftChild;
		}

		LeftChild->m_Parent = Parent;

		LeftChild->m_Right = Node;
		Node->m_Parent = LeftChild;

		Node->m_Left = LeftRightChild;

		return LeftChild;
	}

	void Rebalance(PNODE Node)
	{
		if (!Node)
		{
			return;
		}

		int BF = BalanceFactor(Node);

		// 균형이 무너진 경우
		if (BF < -1 || BF > 1)
		{
			// RR or RL Case
			if (BF < -1)
			{
				// RL Case
				if (BalanceFactor(Node->m_Right) > 0)
				{
					RotateRight(Node->m_Right);
					Node = RotateLeft(Node);
				}

				// RR Case
				else
				{
					Node = RotateLeft(Node);
				}
			}

			// LL or LR Case
			else
			{
				// LL Case
				if (BalanceFactor(Node->m_Left) > 0)
				{
					Node = RotateRight(Node);
				}

				// LR Case
				else
				{
					RotateLeft(Node->m_Left);
					Node = RotateRight(Node);
				}

			}
		}

		// 재귀적으로 계속 올라가면서 Rebalance호출
		Rebalance(Node->m_Parent);
	}

	int BalanceFactor(PNODE Node)
	{
		return Height(Node->m_Left) - Height(Node->m_Right);
	}

};

 

 

#include <iostream>
#include "AVLTree.h"
#include <crtdbg.h>


int main()
{
	_CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);

	CAVLTree<int, int>	tree;

	for (int i = 0; i < 10; ++i)
	{
		tree.insert(i, i);
	}

	tree.erase(5);
	tree.erase(4);
	tree.erase(6);

	CAVLTree<int, int>::iterator	iter;

	for (iter = tree.begin(); iter != tree.end(); ++iter)
	{
		std::cout << "Key : " << iter->first << " Value : " << iter->second << std::endl;
		std::cout << "ParentKey : ";

		if (iter->IsParent())
			std::cout << iter->GetParentKey();

		else
			std::cout << "없음";

		std::cout << " Parent Value : ";

		if (iter->IsParent())
			std::cout << iter->GetParentValue() << std::endl;

		else
			std::cout << "없음" << std::endl;

		std::cout << "LeftKey : ";

		if (iter->IsLeft())
			std::cout << iter->GetLeftKey();

		else
			std::cout << "없음";

		std::cout << " Left Value : ";

		if (iter->IsLeft())
			std::cout << iter->GetLeftValue() << std::endl;

		else
			std::cout << "없음" << std::endl;

		std::cout << "RightKey : ";

		if (iter->IsRight())
			std::cout << iter->GetRightKey();

		else
			std::cout << "없음";

		std::cout << " Right Value : ";

		if (iter->IsRight())
			std::cout << iter->GetRightValue() << std::endl;

		else
			std::cout << "없음" << std::endl;

		std::cout << std::endl;
	}

	return 0;
}

 

출력 결과와 완성된 트리의 모양