Saturday, February 13, 2016

My First A* Search Algorithm: 9Puzzle


The principal idea of the A* search algorithm is to use a cost function that consists of two parts: the known cost and the forecast cost.  The known cost is the distance from the initial status to the current status; the forecast cost is a heuristic function that estimates the distance from the current status to the target status.

The A* search algorithm generalizes both the BFS algorithm and the pure greedy algorithm. If we set the known cost to zero, we get back the pure greedy algorithm, which relies only on the heuristic function; on the other hand, if we set the forecast cost to zero, we get the classic BFS algorithm (every move here costs one step, so ordering by the known cost alone explores the statuses level by level).

The key step in applying the A* search algorithm is constructing the heuristic function. If the heuristic function never overestimates the cost/distance, the algorithm is guaranteed to find the optimal solution.


The following code is the implementation of the 9Puzzle solver.  Because every time the A* search algorithm expands a node it must select the node with minimum cost, we use a priority queue to store the statuses. To avoid revisiting a status, we use a set to keep track of the visited ones.


Here are some notes on the implementation:
  1. The heuristic function plays a crucial role in the performance of the A* search algorithm.
  2. One of the weaknesses of the A* search algorithm and other similar algorithms is that the algorithm needs to remember all the paths explored so far. This is the price we need to pay to become less greedy.
  3. Though raw pointers are not highly recommended in C++, they are still very useful and flexible, in particular when we store objects in containers.

[main.cpp]

#include <utility>
#include <algorithm>
#include <queue>
#include <vector>
#include <iostream>
#include <math.h>
#include <set>
#include <stack>

// project headers
#include "Status.h"
#include "mUtility.h"

using namespace std;



/**
 * Program entry point.
 *
 * Builds a start board and a goal board (9 cells each, 0 marks the empty
 * cell), links the start status to the goal, runs the A* search between
 * them and finally prints a closing banner.
 */
int main(int argc, const char * argv[]) {

    vector<int> startBoard = {1,2,3,4,0,5,6,7,8};
    vector<int> goalBoard = {1,0,8,6,4,5,7,2,3};

    // Both statuses live on the heap and are released at the end of main.
    Status* targetStatus = new Status(goalBoard);
    Status* initStatus = new Status(startBoard);

    // Every status needs to know the goal in order to evaluate its heuristic.
    initStatus->targetStatus = targetStatus;

    // Run the search; it prints its progress and the solution path itself.
    AStarSearch(*initStatus, *targetStatus);

    // Release the two statuses created above.
    delete initStatus;
    delete targetStatus;

    cout << endl << "the end" << endl;

    return 0;
}

 [Status.h]

/*
 * Status.h
 *
 *  Created on: Feb 13, 2016
 *      Author: ruikun
 */

#ifndef STATUS_H_
#define STATUS_H_


#include <vector>
#include <stdlib.h>

using namespace std;


// One board configuration of the sliding puzzle, used as a node of the search.
// All members are public; the search code reads and writes them directly.
class Status
{
public:
    // Creates an empty status: no board, code and posEmpty set to -1, all costs 0.
    Status();
    // Copies the board from arr, locates the cell holding 0 and computes code.
    explicit Status(vector<int>& arr);


    vector<int> arr;      // the board as a flat array; the value 0 marks the empty cell
    int code;      // integer encoding of arr (see computeCode()); used to recognise already-seen statuses
    int existingCost;     // known cost; moveToEmpty() sets it to the parent's value + 1
    int forecastCost;    // this is the value of heuristic function, filled in by update()
    int totalCost;        // existingCost + forecastCost, filled in by update()
    int posEmpty;     // position of the empty cell in the table (-1 if arr holds no 0)
    int sizeOfSide;       // side length of the board, derived from sqrt(arr.size())
    const Status *prev;    // pointer to parent status. It is used when printing the solution path
    Status *targetStatus; // goal status that update() compares against; not owned by this object

    // Recomputes code from the current contents of arr.
    void computeCode();
    // Recomputes forecastCost and totalCost; reads targetStatus->arr.
    void update();     // TODO: think of merge computeCode() and update()
    // Returns a newly allocated status in which the cell at pos has been moved
    // into the empty cell. The caller is responsible for deleting it.
    Status* moveToEmpty(int pos) const;   // implement a move. This will generate a new status.

};




#endif /* STATUS_H_ */


[Status.cpp]

/*
 * Status.cpp
 *
 *  Created on: Feb 13, 2016
 *      Author: ruikun
 */

#include "Status.h"
#include <stdlib.h>
#include <math.h>


/**
 * Default constructor: creates an empty status.
 *
 * The board is empty, code and posEmpty are -1 (meaning "not computed" /
 * "no empty cell found"), all costs are 0 and both pointers are null.
 *
 * The initializers are listed in the order the members are declared in
 * Status.h, because that is the order in which C++ actually initializes them
 * regardless of how the list is written (and compilers warn on a mismatch).
 */
Status::Status()
: arr()
, code(-1)
, existingCost(0)
, forecastCost(0)
, totalCost(0)
, posEmpty(-1)
, sizeOfSide(0)
, prev(nullptr)
, targetStatus(nullptr)
{}



/**
 * Builds a status from a board given as a flat array.
 *
 * @param inArr  board cells; the value 0 marks the empty cell. The vector is
 *               copied and never modified.
 *
 * After construction:
 *   - arr is a copy of inArr,
 *   - sizeOfSide is the truncated square root of the number of cells,
 *   - posEmpty is the index of the first 0 in arr, or -1 if there is none,
 *   - code has been computed from arr,
 *   - the costs are 0 and prev / targetStatus are null.
 *
 * The initializers follow the declaration order of the members in Status.h,
 * which is the order C++ uses regardless of how the list is written.
 */
Status::Status(vector<int>& inArr)
: arr(inArr)
, code(-1)
, existingCost(0)
, forecastCost(0)
, totalCost(0)
, posEmpty(-1)
, sizeOfSide(static_cast<int>(sqrt(static_cast<double>(inArr.size()))))
, prev(nullptr)
, targetStatus(nullptr)
{
    // Locate the empty square (first cell holding 0).
    for (vector<int>::size_type i = 0; i < arr.size(); ++i) {
        if (arr[i] == 0) {
            posEmpty = static_cast<int>(i);
            break;
        }
    }

    computeCode();
}


void  Status::computeCode() {
    int s = 0;
    int factor = 1;
    for (int i = 0; i < arr.size(); ++i) {
        s += arr[i] * factor;
        factor *= 10;
    }
    code = s;
}

void Status::update(){

    // calculate the forecast cost
    int count = 0;
    for (int i = 0; i < arr.size(); ++i)
        if (arr[i] != targetStatus->arr[i])
            ++count;

    forecastCost = count;

    totalCost = existingCost + forecastCost;

}


/**
 * Performs one move: the tile at pos is moved into the empty cell and pos
 * becomes the new empty cell.
 *
 * @param pos  index in arr of the tile to move.
 * @return     a newly allocated status describing the resulting board. Its
 *             existingCost is one more than this status's, it shares this
 *             status's targetStatus, its prev points back to this status and
 *             its code and costs are already computed. The caller must
 *             delete it.
 *
 * This status itself is not modified.
 */
Status* Status::moveToEmpty(int pos) const{
    // Work on a private copy of the board.
    vector<int> board(arr);
    board[posEmpty] = arr[pos];
    board[pos] = 0;

    Status* child = new Status(board);

    // Link the child into the search tree before evaluating its cost.
    child->prev = this;
    child->targetStatus = targetStatus;
    child->existingCost = existingCost + 1;

    child->computeCode();
    child->update();

    return child;
}



[mUtility.h]

/*
 * mUtility.h
 *
 *  Created on: Feb 13, 2016
 *      Author: ruikun
 */

#ifndef MUTILITY_H_
#define MUTILITY_H_


#include <queue>
#include <stack>
#include <set>
#include <algorithm>
#include <iostream>
#include "Status.h"

using namespace std;

// Neighbour lookup on a square board that is stored row by row in a flat array.
// pos is the flat index of a cell and sizeOfSide is the number of cells per row.
// Each function returns the flat index of the adjacent cell in the named
// direction, or -1 when pos already lies on that edge of the board.

int up(int pos, int sizeOfSide);
int down(int pos, int sizeOfSide);
int left(int pos, int sizeOfSide);
int right(int pos, int sizeOfSide);


// Runs the search from initStatus towards targetStatus. Progress information
// and, once the target is reached, the sequence of boards leading to it are
// written to standard output.
void AStarSearch(Status& initStatus, Status& targetStatus);


// Comparator used by TYPE_queue. operator() returns true when lhs has a larger
// totalCost than rhs, which makes the priority queue's top() the status with
// the smallest totalCost.
class myCompare
{
public:
 bool operator()(Status* lhs, Status* rhs);
};

// Open list of the search: a priority queue of status pointers ordered by myCompare.
typedef priority_queue<Status*, vector<Status*>, myCompare> TYPE_queue;


// Expands inStatus: every neighbouring status whose code is not yet in record
// is pushed onto inQueue. Returns true when one of the neighbours has the code
// targetCode; in that case the path from the start to it is printed.
bool addNewStatus(const Status& inStatus, TYPE_queue & inQueue, set<int>& record, int targetCode);



#endif /* MUTILITY_H_ */

[mUtility.cpp]


/*
 * mUtility.cpp
 *
 *  Created on: Feb 13, 2016
 *      Author: ruikun
 */


#include "mUtility.h"
#include <functional>

using namespace std;




// Ordering predicate for the priority queue.
// Returns true when lhs is more expensive than rhs, so that the cheapest
// status ends up at the top of the queue.
bool myCompare::operator() (Status* lhs, Status* rhs)
{
    const int leftCost = lhs->totalCost;
    const int rightCost = rhs->totalCost;
    return rightCost < leftCost;
}



/**
 * Searches for a sequence of moves that turns initStatus into targetStatus.
 *
 * Repeatedly takes the cheapest status (smallest totalCost) off the open
 * list, prints it, and lets addNewStatus() generate its neighbours. The loop
 * stops as soon as addNewStatus() reports that the target was generated (it
 * prints the solution path itself) or when no status is left to expand.
 *
 * @param initStatus    start of the search. Owned by the caller; it is only
 *                      referenced here and never deleted.
 * @param targetStatus  goal of the search; only its code is read.
 *
 * If initStatus already has the target's code, nothing is expanded.
 *
 * Memory: every status created during the search is deleted before the
 * function returns. Expanded statuses have to stay alive until then because
 * later statuses refer to them through prev when the path is printed.
 */
void AStarSearch(Status& initStatus, Status& targetStatus)
{
    TYPE_queue mQueue;
    set<int> mRecord;                 // codes of the statuses generated so far
    vector<const Status*> expanded;   // statuses already taken off the queue

    // initialize
    mQueue.push(&initStatus);
    mRecord.insert(initStatus.code);

    const int targetCode = targetStatus.code;

    cout << "searching for solution..." << endl;

    // The target's code is never inserted into mRecord by addNewStatus(), so
    // the only way the search starts out "solved" is initStatus == targetStatus.
    bool solved = (mRecord.count(targetCode) != 0);

    while (!solved && !mQueue.empty()) {
        cout << "the size of the queue: " << mQueue.size() << endl;
        cout << "the size of the record: " << mRecord.size() << endl;

        // Take the status with minimum cost off the queue BEFORE expanding it.
        // Popping after the children were pushed could remove a child that is
        // cheaper than its parent, so that child would never be expanded.
        Status* current = mQueue.top();
        mQueue.pop();
        expanded.push_back(current);

        for (vector<int>::size_type i = 0; i < current->arr.size(); ++i) {
            cout << current->arr[i] << " ";
        }
        cout << endl;

        cout << "totalCost: " << current->totalCost << endl;
        cout << endl;

        solved = addNewStatus(*current, mQueue, mRecord, targetCode);
    }

    // Release every status the search created. initStatus belongs to the
    // caller and must be skipped wherever it ended up.
    while (!mQueue.empty()) {
        Status* leftover = mQueue.top();
        mQueue.pop();
        if (leftover != &initStatus) {
            delete leftover;
        }
    }

    for (vector<const Status*>::size_type i = 0; i < expanded.size(); ++i) {
        if (expanded[i] != &initStatus) {
            delete expanded[i];
        }
    }
}


bool addNewStatus(const Status& inStatus, TYPE_queue& inQueue, set<int>& record, int targetCode)
{
    int posMove;

    int posEmpty = inStatus.posEmpty;
    int sizeOfSide = inStatus.sizeOfSide;

    bool find_solution = false;
    Status* sol = NULL;
    Status* tmp;

    //TODO: there are duplications of codes here.

    if ( (posMove = up(posEmpty, sizeOfSide)) != -1) {
     tmp = inStatus.moveToEmpty(posMove);
     if (find(record.begin(), record.end(), tmp->code) == record.end()){
      inQueue.push(tmp);
      if (tmp->code == targetCode) {
       find_solution = true;
       sol = tmp;
      } else {
       record.insert(tmp->code);
      }
     }
    }


    if ( (posMove = down(posEmpty, sizeOfSide)) != -1){
     tmp = inStatus.moveToEmpty(posMove);
     if (find(record.begin(), record.end(), tmp->code) == record.end()){
      inQueue.push(tmp);
      if (tmp->code == targetCode) {
       find_solution = true;
       sol = tmp;
      } else {
       record.insert(tmp->code);
      }
     }
    }


    if ( (posMove = left(posEmpty, sizeOfSide)) != -1){
     tmp = inStatus.moveToEmpty(posMove);
     if (find(record.begin(), record.end(), tmp->code) == record.end()){
      inQueue.push(tmp);
      if (tmp->code == targetCode) {
       find_solution = true;
       sol = tmp;
      } else {
       record.insert(tmp->code);
      }
     }
    }


    if ( (posMove = right(posEmpty, sizeOfSide)) != -1){
     tmp = inStatus.moveToEmpty(posMove);
     if (find(record.begin(), record.end(), tmp->code) == record.end()){
      inQueue.push(tmp);
      if (tmp->code == targetCode) {
       find_solution = true;
       sol = tmp;
      } else {
       record.insert(tmp->code);
      }
     }
    }


    // test if we reach the target status
    if (sol) { // reach the target status

        stack<const Status *> mStack;

        const Status *ptr = sol;

        while (ptr != NULL) {
            mStack.push(ptr);
            ptr = ptr->prev;
        }

        while (!mStack.empty()){
            ptr = mStack.top();
            mStack.pop();

            for (vector<int>::size_type i =0; i < ptr->arr.size(); ++i) {
             if (i % ptr->sizeOfSide == 0)
              cout << endl;
             int tmp = ptr->arr[i];

             if (tmp == 0) {
              cout << 'X' << " ";
             } else {
              cout << tmp << " ";
             }
            }

            cout << endl;
        }
    }
    return find_solution;

}






/**
 * Index of the cell directly above pos on a sizeOfSide x sizeOfSide board
 * that is stored row by row.
 *
 * @param pos         flat index of a cell.
 * @param sizeOfSide  number of cells per row (and per column).
 * @return            the index one row up, or -1 if pos is in the top row,
 *                    pos is not a cell of the board, or sizeOfSide is not
 *                    positive.
 */
int up(int pos, int sizeOfSide)
{
    // reject boards without cells and negative positions (also avoids dividing by zero)
    if (sizeOfSide <= 0 || pos < 0) {
        return -1;
    }

    const int row = pos / sizeOfSide;

    if (row == 0 || row >= sizeOfSide) {
        return -1;
    }

    // same column, previous row
    return pos - sizeOfSide;
}


/**
 * Index of the cell directly below pos on a sizeOfSide x sizeOfSide board
 * that is stored row by row.
 *
 * @param pos         flat index of a cell.
 * @param sizeOfSide  number of cells per row (and per column).
 * @return            the index one row down, or -1 if pos is in the bottom
 *                    row, pos is not a cell of the board, or sizeOfSide is
 *                    not positive.
 */
int down(int pos, int sizeOfSide)
{
    // reject boards without cells and negative positions (also avoids dividing by zero)
    if (sizeOfSide <= 0 || pos < 0) {
        return -1;
    }

    const int row = pos / sizeOfSide;

    // bottom row, or a position past the end of the board
    if (row >= sizeOfSide - 1) {
        return -1;
    }

    // same column, next row
    return pos + sizeOfSide;
}

/**
 * Index of the cell directly to the left of pos on a sizeOfSide x sizeOfSide
 * board that is stored row by row.
 *
 * @param pos         flat index of a cell.
 * @param sizeOfSide  number of cells per row (and per column).
 * @return            pos - 1, or -1 if pos is in the leftmost column, pos is
 *                    not a cell of the board, or sizeOfSide is not positive.
 */
int left(int pos, int sizeOfSide)
{
    // reject boards without cells and negative positions (also avoids dividing by zero)
    if (sizeOfSide <= 0 || pos < 0) {
        return -1;
    }

    // position past the end of the board
    if (pos / sizeOfSide >= sizeOfSide) {
        return -1;
    }

    const int col = pos % sizeOfSide;
    return (col == 0) ? -1 : pos - 1;
}

/**
 * Index of the cell directly to the right of pos on a sizeOfSide x sizeOfSide
 * board that is stored row by row.
 *
 * @param pos         flat index of a cell.
 * @param sizeOfSide  number of cells per row (and per column).
 * @return            pos + 1, or -1 if pos is in the rightmost column, pos is
 *                    not a cell of the board, or sizeOfSide is not positive.
 */
int right(int pos, int sizeOfSide)
{
    // reject boards without cells and negative positions (also avoids dividing by zero)
    if (sizeOfSide <= 0 || pos < 0) {
        return -1;
    }

    // position past the end of the board
    if (pos / sizeOfSide >= sizeOfSide) {
        return -1;
    }

    const int col = pos % sizeOfSide;
    return (col == sizeOfSide - 1) ? -1 : pos + 1;
}



// Strict "cheaper than" ordering on statuses:
// true when l_status has a smaller totalCost than r_status.
bool compare(const Status&  l_status, const Status&  r_status)
{
    const bool leftIsCheaper = r_status.totalCost > l_status.totalCost;
    return leftIsCheaper;
}





No comments:

Post a Comment