Wednesday, February 17, 2016

Union-Find Data Structure: C++ Implementation

The union-find data structure, also known as the disjoint-set data structure, is used to keep track of a collection of elements partitioned into disjoint subsets.

The following sources are good starting points to learn more about this data structure:

  1. Union-Find Lecture Note
  2. Data Structures & Algorithm Analysis in C++


Here is a simple C++ implementation of the union-find data structure.

[main.cpp]


#include <iostream>
#include <utility>


#include "UnionFind.h"

using namespace std;

int main()
{
 UnionFind uf;
 uf.add(8);

 uf.unionNode(1,2);
 uf.unionNode(2,3);
 uf.unionNode(4,5);
 uf.unionNode(0,5);

 cout << "print the parent of each node: " << endl;
 for (int i = 0; i < uf._count; ++i) {cout << uf._parent[i] << " ";}
 cout << endl;

 vector<int> out = uf.getLabels();
 cout << "print the label of each node: " << endl;
 for (vector<int>::size_type i = 0; i < out.size(); ++i) {cout << out[i] << " ";}
 cout << endl;

 return 0;
}


[UnionFind.h]


#ifndef UNIONFIND_H_
#define UNIONFIND_H_

#include <vector>
#include <list>

using namespace std;

// Disjoint-set (union-find) forest over integer node ids 0 .. _count-1.
// The bookkeeping vectors are public and may be read directly by callers.
class UnionFind
{
public:
 // Creates an empty structure holding no nodes.
 UnionFind();

 // Appends `repeat` new nodes (default 1), each initially a root of its own set.
 void add(int repeat=1);

 // Returns the root of the tree containing `node`, or -1 when `node` is itself
 // a root. Re-points the visited path directly at the root.
 int find(int node);

 // Unions the set containing node1 with the set containing node2.
 void unionNode(int node1, int node2);

 // Unions every node i with parentArr[i]; entries equal to -1 are skipped.
 // Throws std::invalid_argument when parentArr.size() differs from _count.
 void unionForest(std::vector<int> parentArr);

 // Returns one label per node: nodes in the same set share a label, and labels
 // are numbered 0, 1, 2, ... in order of first appearance by node id.
 std::vector<int> getLabels();

 std::vector<int> _parent; // parent id of each node; -1 for a root
 std::vector<int> _rank;   // per-node balancing bookkeeping used by unionNode
 int _count;               // number of nodes added so far
};


#endif /* UNIONFIND_H_ */

[UnionFind.cpp]


#include <stdexcept>
#include "UnionFind.h"

using namespace std;

// Starts with no nodes: both bookkeeping vectors are default-constructed
// (empty) and the node counter is zero. Nodes are created later via add().
UnionFind::UnionFind()
 : _count(0)
{
}

// Appends `repeat` fresh nodes. Each new node is a root of its own singleton set
// (parent -1) with zero rank. A non-positive `repeat` leaves the structure unchanged.
void UnionFind::add(int repeat)
{
 if (repeat <= 0)
  return;

 const vector<int>::size_type extra = static_cast<vector<int>::size_type>(repeat);
 _parent.resize(_parent.size() + extra, -1);
 _rank.resize(_rank.size() + extra, 0);
 _count += repeat;
}


// Locates the root of the tree containing `node` and compresses the path so that
// every node visited on the way now points directly at that root.
//
// @param node  id of a node previously created with add().
// @return -1 when `node` is itself a root (getLabels() relies on this
//         convention); otherwise the id of the root of its tree.
// @throws std::out_of_range when `node` is not a valid node id.
//
// _rank is deliberately left untouched: path compression only re-points links
// and never changes which set a node belongs to, so a root's balancing
// bookkeeping must stay as unionNode left it.
int UnionFind::find(int node)
{
 // Bounds-checked access: an invalid id throws instead of reading out of range.
 if (_parent.at(node) == -1)
  return -1;

 // Pass 1: follow parent links up to the root.
 int root = node;
 while (_parent[root] != -1)
  root = _parent[root];

 // Pass 2: hang every node on the path directly off the root.
 // A second walk replaces the temporary list the path used to be stored in.
 int current = node;
 while (current != root) {
  const int next = _parent[current];
  _parent[current] = root;
  current = next;
 }

 return root;
}


// Merges the set containing node1 with the set containing node2.
//
// The link is always made between the two ROOTS. Re-parenting a non-root node
// would detach it (and its subtree) from its original tree instead of merging
// the sets. The tree whose root has the smaller _rank is attached under the
// other root. _rank[root] is maintained as the number of nodes below that root:
// it starts at 0 in add() and grows here by the absorbed tree's size. On a tie,
// node1's root goes under node2's root.
//
// Does nothing when both nodes already belong to the same set.
void UnionFind::unionNode(int node1, int node2)
{
 if (node1 == node2) return; // a node is trivially in its own set

 // find() reports -1 when its argument is already a root; map that back to the node.
 int root1 = find(node1);
 if (root1 == -1) root1 = node1;
 int root2 = find(node2);
 if (root2 == -1) root2 = node2;

 if (root1 == root2) return; // already in the same set (tree)

 // Hang the smaller tree under the larger one to keep trees shallow.
 if (_rank[root1] > _rank[root2]) {
  _parent[root2] = root1;
  _rank[root1] += _rank[root2] + 1;
 } else {
  _parent[root1] = root2;
  _rank[root2] += _rank[root1] + 1;
 }
}


// Merges in the connectivity described by another parent array of the same size.
// Every node idx is unioned with parentArr[idx]; an entry of -1 marks a root and
// needs no merge.
// Throws invalid_argument when parentArr does not describe exactly _count nodes.
void UnionFind::unionForest(vector<int> parentArr)
{
 const vector<int>::size_type n = parentArr.size();
 if (n != static_cast<vector<int>::size_type>(_count))
  throw invalid_argument("The number of nodes does not match");

 for (vector<int>::size_type idx = 0; idx < n; ++idx) {
  const int parentOfIdx = parentArr[idx];
  if (parentOfIdx == -1)
   continue; // a root in the foreign forest: nothing to link
  unionNode(static_cast<int>(idx), parentOfIdx);
 }
}


// Produces one compact label per node: nodes in the same set share a label, and
// labels are handed out as 0, 1, 2, ... in the order their sets are first met
// while scanning node ids upward. Calls find() on every node, so paths get
// compressed as a side effect.
vector<int> UnionFind::getLabels()
{
 vector<int> labels(_count, -1);
 int nextLabel = 0;

 for (int node = 0; node < _count; ++node) {
  const int found = find(node);
  // find() answers -1 when node is itself a root; then node represents its own set.
  const int rep = (found == -1) ? node : found;

  // The first member seen from a set assigns that set a fresh label on its root.
  if (labels[rep] == -1)
   labels[rep] = nextLabel++;
  labels[node] = labels[rep];
 }

 return labels;
}

No comments:

Post a Comment