Wednesday, February 17, 2016

Two-pass algorithm for binary image

The two-pass algorithm for labeling a binary image is one of the applications of the union-find data structure. The idea of the algorithm is described in this lecture note (page 10/24).


Here is a simple implementation of the two-pass algorithm on a binary image. The algorithm assigns a label to each connected component in the image. Connectivity is tested using 4-connectivity: a pixel is connected to its upper, lower, left and right neighbors.

Sample input:

0 1 0 0 0 0 1
1 1 0 1 0 1 1
0 0 0 1 1 0 0

Sample output:

0 1 0 0 0 0 2
1 1 0 3 0 2 2
0 0 0 3 3 0 0



[twopass.h]


#ifndef TWOPASS_H_
#define TWOPASS_H_


#include <vector>
#include "UnionFind.h"

// Dense 2-D grid of ints: m[i][j] is row i, column j.
// Spelled with std:: so this header does not depend on a using-directive
// leaking in from another header.
typedef std::vector< std::vector<int> > matrix;

// Labels the 4-connected components of the 1-cells in inMatrix.
// Background cells come back as 0, components as 1, 2, ...
matrix twoPassAlgorithm(const matrix & inMatrix);

#endif /* TWOPASS_H_ */


[twopass.cpp]

/*
 * twopass.cpp
 *
 *  Created on: Feb 17, 2016
 *      Author: ruikun
 */

#include "twopass.h"

matrix twoPassAlgorithm(const matrix & inMatrix)
{
 int nrow = inMatrix.size();
 int ncol = inMatrix[0].size();

 matrix label(nrow, vector<int>(ncol, -1)); // -1 indicates that there is no label in the cell


 UnionFind uf;
 // first pass
 int label_up;
 int label_left;

 for (int i = 0; i < nrow; ++i) {
  for (int j = 0; j < ncol; ++j) {

   if (inMatrix[i][j] == 1) {
    // boundary case: we can put boundary case outside the loop
    if (i == 0 && j == 0) {
     uf.add();
     label[0][0] = uf.getLatestLabel();
     continue;
    }

    if (i == 0) {
     if (label[0][j-1] != -1) {
      label[0][j] = label[0][j-1];
     } else {
      uf.add();
      label[0][j] = uf.getLatestLabel();
     }
     continue;
    }

    if (j == 0) {
     if (label[i-1][0] != -1) {
      label[i][0] = label[i-1][0];
     } else {
      uf.add();
      label[i][0] = uf.getLatestLabel();
     }
     continue;
    }

    // normal case

    // if only one of upper and left cell has label, then copy the label
    label_up = label[i-1][j];
    label_left = label[i][j-1];

    if (label_up != -1 && label_left == -1) {
     label[i][j] = label_up;
    } else if (label_up == -1 && label_left != -1) {
     label[i][j] = label_left;
    } else if (label_up != -1 && label_left != 1 && label_up == label_left) {
     label[i][j] = label_up;
    } else if (label_up != label_left) {
     label[i][j] = label_up;
     uf.unionNode(label_up, label_left);
    } else {  // in this case, neither of upper cell nor left cell has a label
     uf.add();
     label[i][j] = uf.getLatestLabel(); // get new label

    }
   }
  }
 }

 // re-labeling
 vector<int> classNumber(uf.getLabels());

 int _label;
 for (int i = 0; i < nrow; ++i) {
  for (int j = 0; j < ncol; ++j){

   _label = label[i][j];
   label[i][j] = (_label == -1 ? -1 : classNumber[label[i][j]]) + 1;
  }
 }


 return label;
}

Union-Find Data Structure: C++ Implementation

The union-find data structure, also known as the disjoint-set data structure, is used to keep track of a collection of disjoint subsets of elements.

The following sources are good starting points to learn more about this data structure:

  1. Union-Find Lecture Note
  2. Data Structures & Algorithm Analysis in C++


Here is a simple C++ implementation of the union-find data structure.

[main.cpp]


#include <iostream>
#include <utility>


#include "UnionFind.h"

using namespace std;

int main()
{
 // Eight singleton nodes 0..7.
 UnionFind uf;
 uf.add(8);

 // Build the classes {1,2,3} and {0,4,5}; nodes 6 and 7 stay alone.
 const int pairs[][2] = { {1, 2}, {2, 3}, {4, 5}, {0, 5} };
 for (int k = 0; k < 4; ++k)
  uf.unionNode(pairs[k][0], pairs[k][1]);

 cout << "print the parent of each node: " << endl;
 for (int node = 0; node < uf._count; ++node)
  cout << uf._parent[node] << " ";
 cout << endl;

 // getLabels() runs find(), which may compress paths, so it is called only
 // after the raw parent array has been printed.
 const vector<int> labels = uf.getLabels();
 cout << "print the label of each node: " << endl;
 for (vector<int>::const_iterator it = labels.begin(); it != labels.end(); ++it)
  cout << *it << " ";
 cout << endl;

 return 0;
}


[UnionFind.h]


#ifndef UNIONFIND_H_
#define UNIONFIND_H_

#include <vector>
#include <list>

using namespace std;

/*
 * Disjoint-set forest over the nodes 0 .. _count-1.
 *
 * _parent[i] is the parent of node i, or -1 when i is a root.
 * _rank[i]   is accumulated by unionNode() as a count of the nodes merged
 *            below root i.
 */
class UnionFind
{
public:
 UnionFind();
 vector<int> _parent;
 vector<int> _rank;
 int _count;

 // Returns the root of node's tree, or -1 if node itself is a root.
 int find(int node);
 // Unions the sets containing node1 and node2.
 void unionNode(int node1, int node2);
 // Unions every node i with parentArr[i] (entries of -1 are skipped);
 // parentArr.size() must equal _count.
 void unionForest(vector<int> parentArr);
 // Appends `repeat` new singleton nodes.
 void add(int repeat=1);
 // Returns a dense class id (0, 1, ...) for every node.
 vector<int> getLabels();
 // Index of the most recently added node (_count - 1, or -1 when empty).
 // twoPassAlgorithm() reads back the label created by add() through this.
 int getLatestLabel() const { return _count - 1; }
};


#endif /* UNIONFIND_H_ */

[UnionFind.cpp]


#include <stdexcept>
#include "UnionFind.h"

using namespace std;

// Empty structure: both vectors default-construct empty and no node exists.
UnionFind::UnionFind()
: _count(0)
{
}

void UnionFind::add(int repeat)
{
 // Append `repeat` singleton nodes: parent -1 marks each as a root and rank
 // 0 means nothing hangs below it yet. Non-positive counts add nothing.
 if (repeat <= 0)
  return;

 _parent.insert(_parent.end(), repeat, -1);
 _rank.insert(_rank.end(), repeat, 0);
 _count += repeat;
}


int UnionFind::find(int node)
{
 // A root reports -1 instead of its own index; getLabels() tests for -1.
 if (_parent[node] == -1)
  return -1;

 // Climb to the root of node's tree.
 int root = node;
 while (_parent[root] != -1)
  root = _parent[root];

 // Path compression: re-hang every node on the climbed path directly under
 // the root. _rank is deliberately left alone: unionNode() keeps it as a
 // per-root count of descendants, and compressing a path does not change
 // how many nodes the root's set contains.
 while (node != root) {
  const int next = _parent[node];
  _parent[node] = root;
  node = next;
 }

 return root;
}


void UnionFind::unionNode(int node1, int node2)
{
 if (node1 == node2) return; // if the nodes are the same, do nothing

 // Resolve both nodes to their roots. find() reports -1 for a node that is
 // itself a root, so map that case back to the node's own index.
 int root1 = find(node1);
 if (root1 == -1) root1 = node1;
 int root2 = find(node2);
 if (root2 == -1) root2 = node2;

 // Already in the same tree: nothing to merge.
 if (root1 == root2) return;

 // Link root to root. Re-parenting an inner node would detach it from its
 // old tree. Hang the smaller tree under the larger one. _rank[r] counts the
 // nodes below root r, so the new root absorbs the other tree's count plus
 // the other root itself.
 if (_rank[root1] > _rank[root2]) {
  _parent[root2] = root1;
  _rank[root1] += _rank[root2] + 1;
 } else {
  _parent[root1] = root2;
  _rank[root2] += _rank[root1] + 1;
 }
}


void UnionFind::unionForest(vector<int> parentArr)
{
 // parentArr must describe exactly the nodes already in this structure.
 if (parentArr.size() != static_cast<vector<int>::size_type>(_count))
  throw invalid_argument("The number of nodes does not match");

 // Validate every entry before merging anything, so a bad array leaves the
 // structure untouched instead of half-merged (and never indexes out of range).
 for (vector<int>::size_type i = 0; i < parentArr.size(); ++i) {
  const int p = parentArr[i];
  if (p != -1 && (p < 0 || p >= _count))
   throw invalid_argument("Parent index out of range");
 }

 // parentArr[i] == -1 marks a root; any other entry joins node i with it.
 for (vector<int>::size_type i = 0; i < parentArr.size(); ++i) {
  if (parentArr[i] != -1)
   unionNode(static_cast<int>(i), parentArr[i]);
 }
}


vector<int> UnionFind::getLabels()
{
 // Classes receive ids 0, 1, 2, ... in the order in which their first
 // (lowest-index) member is visited below.
 vector<int> label(_count, -1);
 int nextLabel = 0;

 for (int node = 0; node < _count; ++node) {
  const int root = find(node);

  if (root == -1) {
   // node is a root: label it unless a descendant already did.
   if (label[node] == -1)
    label[node] = nextLabel++;
   continue;
  }

  // First member of this class seen: give the root a fresh id.
  if (label[root] == -1)
   label[root] = nextLabel++;
  label[node] = label[root];
 }

 return label;
}

Saturday, February 13, 2016

My First A* Search Algorithm: 9Puzzle


The principal idea of the A* search algorithm is to have a cost function that consists of two parts: the known cost and the forecast cost. The known cost is the distance between the current status and the initial status; the forecast cost is a heuristic function that estimates the distance between the current status and the target status.

The A* search algorithm is a generalization of both the BFS algorithm and the pure greedy algorithm. If we set the known cost to zero, we get back the pure greedy algorithm, which relies only on the heuristic function. If instead we set the forecast cost to zero, we get uniform-cost search. In a puzzle like this one, where every move costs 1, that behaves like the classic BFS algorithm.

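The key component of the A* search algorithm is the heuristic function. If the heuristic function never overestimates the remaining cost/distance (i.e. it is admissible), the algorithm is guaranteed to find the optimal solution. When a closed set of visited nodes is used, the heuristic should also be consistent.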


The following code implements a solver for the 3x3 sliding puzzle (9 cells, one of them empty). The A* search algorithm must always expand the node with the minimum cost, so we use a priority queue to store the statuses waiting to be expanded. To avoid revisiting statuses, we use a set to keep track of the visited nodes.


Here are some notes on the implementation:
  1. The heuristic function plays a crucial role in the performance of the A* search algorithm.
  2. One of the weaknesses of the A* search algorithm and other similar algorithms is that it needs to remember all the historical paths. This is the price we pay for being less greedy.
  3. Raw pointers are not generally recommended in modern C++, but they are still very useful and flexible, in particular when storing objects in containers.

[main.cpp]

#include <utility>
#include <algorithm>
#include <queue>
#include <vector>
#include <iostream>
#include <math.h>
#include <set>
#include <stack>

// project headers
#include "Status.h"
#include "mUtility.h"

using namespace std;



int main(int argc, const char * argv[]) {

    vector<int> init_arr = {1,2,3,4,0,5,6,7,8};
    vector<int> target_arr = {1,0,8,6,4,5,7,2,3};

    Status* initStatus = new Status(init_arr);
    Status* targetStatus = new Status(target_arr);

    // set the target status
    initStatus->targetStatus = targetStatus;


    // start the A* search
    AStarSearch(*initStatus, *targetStatus);

    // release memory
    delete initStatus;
    delete targetStatus;

    cout << endl << "the end" << endl;

    return 0;
}  

 [Status.h]

/*
 * Status.h
 *
 *  Created on: Feb 13, 2016
 *      Author: ruikun
 */

#ifndef STATUS_H_
#define STATUS_H_


#include <vector>
#include <stdlib.h>

using namespace std;


// One board configuration of the sliding puzzle plus the bookkeeping the
// A* search needs (costs, parent link, goal board).
class Status
{
public:
    // Empty board: no cells, code -1, zero costs, no blank (-1), NULL links.
    Status();
    // Copies `arr`, sets sizeOfSide to floor(sqrt(arr.size())), takes the
    // first 0 as the blank and computes the code. Costs start at 0, prev and
    // targetStatus at NULL. (arr is taken by non-const reference but only read.)
    explicit Status(vector<int>& arr);


    vector<int> arr;      // board cells, row by row; 0 is the empty cell
    int code;      // base-10 encoding of arr (see computeCode()); used as the visited-set key by the search
    int existingCost;     // known cost g: number of moves made since the initial status
    int forecastCost;    // heuristic value h, set by update()
    int totalCost;        // g + h, set by update(); the priority-queue ordering key
    int posEmpty;     // index of the empty (0) cell in arr, or -1 if there is none
    int sizeOfSide;       // board width (and height)
    const Status *prev;    // parent status (NULL for the initial one); followed when printing the solution path
    Status *targetStatus;  // goal board; update() dereferences it, so it must be set before update() runs

    void computeCode();    // recompute `code` from arr
    void update();     // recompute forecastCost and totalCost. TODO: think of merge computeCode() and update()
    Status* moveToEmpty(int pos) const;   // slide the tile at `pos` into the blank; returns a new Status allocated with `new`

};




#endif /* STATUS_H_ */


[Status.cpp]

/*
 * Status.cpp
 *
 *  Created on: Feb 13, 2016
 *      Author: ruikun
 */

#include "Status.h"
#include <stdlib.h>
#include <math.h>


// Default status: an empty board with code -1, zero costs, no blank (-1),
// zero side length and no parent/target links.
// Initialisers are listed in declaration order (prev before targetStatus),
// which is the order the compiler actually runs them in.
Status::Status()
: arr()
, code(-1)
, existingCost(0)
, forecastCost(0)
, totalCost(0)
, posEmpty(-1)
, sizeOfSide(0)
, prev(NULL)
, targetStatus(NULL)
{}



// Builds the status for board `inArr` (cells row by row, 0 = empty cell):
// copies the cells, derives the side length, locates the blank and computes
// the code. Costs start at 0; prev and targetStatus start as NULL.
// Initialisers are listed in declaration order, which is the order the
// compiler actually runs them in.
Status::Status(vector<int>& inArr)
: arr(inArr)
, code(-1)
, existingCost(0)
, forecastCost(0)
, totalCost(0)
, posEmpty(-1)
, sizeOfSide(static_cast<int>(sqrt(static_cast<double>(inArr.size()))))
, prev(NULL)
, targetStatus(NULL)
{
    // The blank is the first cell holding 0; posEmpty stays -1 if none does.
    for (vector<int>::size_type i = 0; i < arr.size(); ++i) {
        if (arr[i] == 0) {
            posEmpty = static_cast<int>(i);
            break;
        }
    }

    computeCode();
}


void  Status::computeCode() {
    // Read the board as a base-10 number whose least significant digit is
    // cell 0 (Horner's scheme, walking from the last cell down). The code is
    // unique only while every cell value is below 10 and the result fits in
    // an int, e.g. the 3x3 board; bigger boards would collide or overflow.
    int s = 0;
    for (vector<int>::size_type k = arr.size(); k-- > 0; )
        s = s * 10 + arr[k];
    code = s;
}

void Status::update(){

    // Forecast cost h: number of tiles that are not on their target cell.
    // The empty cell (0) is skipped. Counting it would overestimate the moves
    // still needed (a board one slide from the target would score 2), and an
    // overestimating heuristic voids A*'s optimality guarantee.
    // Assumes targetStatus is set and its board has the same size as arr.
    int misplaced = 0;
    for (vector<int>::size_type i = 0; i < arr.size(); ++i) {
        if (arr[i] != 0 && arr[i] != targetStatus->arr[i])
            ++misplaced;
    }

    forecastCost = misplaced;
    totalCost = existingCost + forecastCost;
}


Status* Status::moveToEmpty(int pos) const{
    // Copy the board and slide the tile at `pos` into the blank cell.
    vector<int> board(arr);
    board[posEmpty] = arr[pos];
    board[pos] = 0;

    // The constructor derives posEmpty, sizeOfSide and code from `board`;
    // the search bookkeeping is filled in here before the costs are computed.
    Status* child = new Status(board);
    child->prev = this;
    child->targetStatus = targetStatus;
    child->existingCost = existingCost + 1;
    child->computeCode();
    child->update();

    return child;
}



[mUtility.h]

/*
 * mUtility.h
 *
 *  Created on: Feb 13, 2016
 *      Author: ruikun
 */

#ifndef MUTILITY_H_
#define MUTILITY_H_


#include <queue>
#include <stack>
#include <set>
#include <algorithm>
#include <iostream>
#include "Status.h"

using namespace std;

// Neighbour lookup on a sizeOfSide x sizeOfSide board stored row by row:
// each function returns the index of the cell above / below / left of /
// right of `pos`, or -1 when `pos` lies on the corresponding edge.
// NOTE: `left` and `right` share their names with the std::left / std::right
// stream manipulators made visible by `using namespace std;` above; calls with
// (int, int) arguments still resolve to these functions.
int up(int pos, int sizeOfSide);
int down(int pos, int sizeOfSide);
int left(int pos, int sizeOfSide);
int right(int pos, int sizeOfSide);


// Runs the A* search from initStatus towards the board of targetStatus,
// printing progress and, once the target is reached, the solution path.
void AStarSearch(Status& initStatus, Status& targetStatus);


// Priority-queue comparator: lhs ranks below rhs when lhs is more expensive,
// which keeps the status with the smallest totalCost on top.
class myCompare
{
public:
 bool operator()(Status* lhs, Status* rhs);
};

// Open list of the search, ordered by myCompare (cheapest status on top).
typedef priority_queue<Status*, vector<Status*>, myCompare> TYPE_queue;


// Generates the successors of inStatus, queues those whose code is not yet in
// `record`, and returns true (after printing the path) if one of them has
// code == targetCode.
bool addNewStatus(const Status& inStatus, TYPE_queue & inQueue, set<int>& record, int targetCode);



#endif /* MUTILITY_H_ */

[mUtility.cpp]


/*
 * mUtility.cpp
 *
 *  Created on: Feb 13, 2016
 *      Author: ruikun
 */


#include "mUtility.h"
#include <functional>

using namespace std;




// Min-heap ordering for TYPE_queue: lhs sinks below rhs when it costs more,
// so the status with the smallest totalCost surfaces at top().
bool myCompare::operator() (Status* lhs, Status* rhs) {
    const int lhsCost = lhs->totalCost;
    const int rhsCost = rhs->totalCost;
    return rhsCost < lhsCost;
}



void AStarSearch(Status& initStatus, Status& targetStatus)
{
    TYPE_queue mQueue;
    set<int> mRecord;       // keep record of visited status

    // initialize
    mQueue.push(&initStatus);

    mRecord.insert(initStatus.code);

    int targetCode = targetStatus.code;



    cout << "searching for solution..." << endl;



    while (!mQueue.empty() && find(mRecord.begin(), mRecord.end(), targetCode) == mRecord.end()){
     cout << "the size of the queue: " << mQueue.size() << endl;
     cout << "the size of the record: " << mRecord.size() << endl;
        // pop the status with minimum cost
        const Status * ptrStatus = mQueue.top();


        for (vector<int>::size_type i = 0; i < ptrStatus->arr.size(); ++i){
         cout << ptrStatus->arr[i] << " " ;
        }
        cout << endl;

        cout << "totalCost: " << ptrStatus->totalCost << endl;
        cout << endl;

        if (addNewStatus(*ptrStatus, mQueue, mRecord, targetCode)) {
         break;
        }

        mQueue.pop();
    }

    /* TODO:
     * (1) need to handle the case where initStatus == targetStatus
     * (2) the current code does not clear all the memory
     */

    // clear the memory
    while (!mQueue.empty()) {
     delete mQueue.top();
     mQueue.pop();

    }

}


bool addNewStatus(const Status& inStatus, TYPE_queue& inQueue, set<int>& record, int targetCode)
{
    int posMove;

    int posEmpty = inStatus.posEmpty;
    int sizeOfSide = inStatus.sizeOfSide;

    bool find_solution = false;
    Status* sol = NULL;
    Status* tmp;

    //TODO: there are duplications of codes here.

    if ( (posMove = up(posEmpty, sizeOfSide)) != -1) {
     tmp = inStatus.moveToEmpty(posMove);
     if (find(record.begin(), record.end(), tmp->code) == record.end()){
      inQueue.push(tmp);
      if (tmp->code == targetCode) {
       find_solution = true;
       sol = tmp;
      } else {
       record.insert(tmp->code);
      }
     }
    }


    if ( (posMove = down(posEmpty, sizeOfSide)) != -1){
     tmp = inStatus.moveToEmpty(posMove);
     if (find(record.begin(), record.end(), tmp->code) == record.end()){
      inQueue.push(tmp);
      if (tmp->code == targetCode) {
       find_solution = true;
       sol = tmp;
      } else {
       record.insert(tmp->code);
      }
     }
    }


    if ( (posMove = left(posEmpty, sizeOfSide)) != -1){
     tmp = inStatus.moveToEmpty(posMove);
     if (find(record.begin(), record.end(), tmp->code) == record.end()){
      inQueue.push(tmp);
      if (tmp->code == targetCode) {
       find_solution = true;
       sol = tmp;
      } else {
       record.insert(tmp->code);
      }
     }
    }


    if ( (posMove = right(posEmpty, sizeOfSide)) != -1){
     tmp = inStatus.moveToEmpty(posMove);
     if (find(record.begin(), record.end(), tmp->code) == record.end()){
      inQueue.push(tmp);
      if (tmp->code == targetCode) {
       find_solution = true;
       sol = tmp;
      } else {
       record.insert(tmp->code);
      }
     }
    }


    // test if we reach the target status
    if (sol) { // reach the target status

        stack<const Status *> mStack;

        const Status *ptr = sol;

        while (ptr != NULL) {
            mStack.push(ptr);
            ptr = ptr->prev;
        }

        while (!mStack.empty()){
            ptr = mStack.top();
            mStack.pop();

            for (vector<int>::size_type i =0; i < ptr->arr.size(); ++i) {
             if (i % ptr->sizeOfSide == 0)
              cout << endl;
             int tmp = ptr->arr[i];

             if (tmp == 0) {
              cout << 'X' << " ";
             } else {
              cout << tmp << " ";
             }
            }

            cout << endl;
        }
    }
    return find_solution;

}






int up(int pos, int sizeOfSide)
{
    // No cell above anything in the first row; otherwise step back one row.
    return (pos / sizeOfSide == 0) ? -1 : pos - sizeOfSide;
}


int down(int pos, int sizeOfSide)
{
    // No cell below anything in the last row; otherwise step forward one row.
    const bool onLastRow = (pos / sizeOfSide) == sizeOfSide - 1;
    if (onLastRow)
        return -1;
    return pos + sizeOfSide;
}

int left(int pos, int sizeOfSide)
{
    // Moving left is impossible from the first column.
    if (pos % sizeOfSide != 0)
        return pos - 1;
    return -1;
}

int right(int pos, int sizeOfSide)
{
    // Moving right is impossible from the last column.
    const bool onLastColumn = (pos % sizeOfSide) == sizeOfSide - 1;
    return onLastColumn ? -1 : pos + 1;
}



// Orders two statuses by ascending totalCost (cheapest first).
// NOTE(review): not referenced anywhere in the visible sources.
bool compare(const Status&  l_status, const Status&  r_status)
{
    return !(r_status.totalCost <= l_status.totalCost);
}