#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <libpq-fe.h>

void usage();
void setup(PGconn *conn);
void check(PGconn *conn);
void run(PGconn *conn);
void cleanup(PGconn *conn);
void reset(PGconn *conn);
int getRand(int max);
void doQry(char *qry, PGconn *conn);
char *doScalar(char *qry, PGconn *conn);

int main(int argc, char *argv[])
{
    // Check the syntax
    if (argc < 5) usage();

    srand(time(0));

    // Open the db connection
    char connstr[1024];
    PGconn *conn;

    sprintf(connstr, "hostaddr=%s dbname=%s user=%s", argv[2], argv[3], argv[4]);
    if (argc == 4) {
        strcat(connstr, " password=");
        strcat(connstr, argv[5]);
    }

    conn = PQconnectdb(connstr);

    if (PQstatus(conn) == CONNECTION_BAD) {
        fprintf(stderr, "Couldn't connect to the database!!\n\n");
        exit(1);
    }

    if (strcmp(argv[1], "setup") == 0)
        setup(conn);
    else if (strcmp(argv[1], "cleanup") == 0)
        cleanup(conn);
    else if (strcmp(argv[1], "reset") == 0)
        reset(conn);
    else if (strcmp(argv[1], "run") == 0)
        run(conn);
    else if (strcmp(argv[1], "check") == 0)
        check(conn);
    else
        usage();

    // Cleanup & exit
    PQfinish(conn);

  return 0;
}

/*
 * Print the command-line synopsis to stderr and terminate with status 1.
 * Called for too few arguments or an unrecognised command; never returns.
 */
void usage(void)
{
    fprintf(stderr, "Usage: pgtest [setup|run|cleanup|check|reset] <Server IP Address> <DB Name> <Username> [<Password>]\n\n");
    exit(1);
}

/*
 * Create the pgtest table and load 10000 rows inside one transaction.
 * Each row holds five random values in 1..10000 plus their sum (total) and
 * integer mean (average), which check() recomputes later. Any failing
 * statement terminates the process inside doQry().
 */
void setup(PGconn *conn)
{
    char sql[1024];
    int row;

    doQry("CREATE TABLE pgtest(id serial, v1 int4, v2 int4, v3 int4, v4 int4, v5 int4, total int4, average int4);", conn);
    doQry("BEGIN;", conn);

    for (row = 0; row < 10000; row++) {
        int v[5];
        int sum = 0;
        int k;

        // Draw v1..v5 in order, accumulating the row total as we go
        for (k = 0; k < 5; k++) {
            v[k] = getRand(10000);
            sum += v[k];
        }

        sprintf(sql, "INSERT INTO pgtest (v1, v2, v3, v4, v5, total, average) VALUES (%d, %d, %d, %d, %d, %d, %d);",
                v[0], v[1], v[2], v[3], v[4], sum, sum / 5);
        doQry(sql, conn);
    }

    doQry("COMMIT;", conn);
}

/*
 * Verify that pgtest still holds exactly the rows setup() created and that
 * every row is internally consistent. Each failed check is reported on
 * stderr; if any check fails the connection is closed and the process exits
 * with status 1. Returns normally only when every check passes.
 *
 * We need to check for corrupt or duplicated records ("SB" = should be):
 *   SELECT COUNT(DISTINCT(id)) FROM pgtest                     SB 10000
 *   SELECT COUNT(*) FROM pgtest                                SB 10000
 *   SELECT MAX(id)                                             SB 10000
 *   SELECT MIN(id)                                             SB 1
 *   SELECT COUNT(*) WHERE v1+v2+v3+v4+v5 != total              SB 0
 *   SELECT COUNT(*) WHERE (v1+v2+v3+v4+v5) / 5 != average      SB 0
 */
void check(PGconn *conn)
{
    char qry[1024], *res;
    int err = 0;

    fprintf(stderr, "Checking data consistency...\n");

    // Check for duplicate ids (distinct count drops below 10000)
    snprintf(qry, sizeof qry, "SELECT COUNT(DISTINCT(id)) FROM pgtest;");
    res = doScalar(qry, conn);
    if (strcmp(res, "10000") != 0) {
        fprintf(stderr, "DISTINCT CHECK - Duplicate or missing rows detected (%s)!!\n", res);
        err = 1;
    }

    // Check the total row count
    snprintf(qry, sizeof qry, "SELECT COUNT(*) FROM pgtest;");
    res = doScalar(qry, conn);
    if (strcmp(res, "10000") != 0) {
        fprintf(stderr, "COUNT CHECK - Duplicate or missing rows detected (%s)!!\n", res);
        err = 1;
    }

    // Check maximum ID; a mismatch may be too high or too low, so report it
    snprintf(qry, sizeof qry, "SELECT MAX(id) FROM pgtest;");
    res = doScalar(qry, conn);
    if (strcmp(res, "10000") != 0) {
        fprintf(stderr, "MAX CHECK - Invalid maximum row ID detected (%s, expected 10000)!!\n", res);
        err = 1;
    }

    // Check minimum ID
    snprintf(qry, sizeof qry, "SELECT MIN(id) FROM pgtest;");
    res = doScalar(qry, conn);
    if (strcmp(res, "1") != 0) {
        fprintf(stderr, "MIN CHECK - Invalid minimum row ID detected (%s, expected 1)!!\n", res);
        err = 1;
    }

    // Check totals
    snprintf(qry, sizeof qry, "SELECT COUNT(*) FROM pgtest WHERE (v1 + v2 + v3 + v4 + v5) != total;");
    res = doScalar(qry, conn);
    if (strcmp(res, "0") != 0) {
        fprintf(stderr, "TOTAL CHECK - %s invalid total(s) detected!!\n", res);
        err = 1;
    }

    // Check averages (integer division, matching how they were computed)
    snprintf(qry, sizeof qry, "SELECT COUNT(*) FROM pgtest WHERE ((v1 + v2 + v3 + v4 + v5) / 5) != average;");
    res = doScalar(qry, conn);
    if (strcmp(res, "0") != 0) {
        fprintf(stderr, "AVERAGE CHECK - %s invalid average(s) detected!!\n", res);
        err = 1;
    }

    if (err > 0) {
        fprintf(stderr, "1 or more errors detected - exiting!\n\n");
        PQfinish(conn);
        exit(1);
    }
}

/*
 * Stress loop that never returns: verify the table with check(), then apply
 * a batch of 1..1000 random single-row UPDATEs in one transaction, and
 * repeat. Each update rewrites v1..v5 of a random id in 1..10000 together
 * with a matching total and average, so check() should keep passing.
 */
void run(PGconn *conn)
{
    char sql[1024];

    for (;;) {
        int batch;
        int i;

        check(conn);

        batch = getRand(1000);
        fprintf(stderr, "Starting run of %d updates.\n", batch);

        doQry("BEGIN;", conn);

        for (i = 0; i < batch; i++) {
            int target = getRand(10000);   // drawn before the values, as always
            int v[5];
            int sum = 0;
            int k;

            for (k = 0; k < 5; k++) {
                v[k] = getRand(10000);
                sum += v[k];
            }

            sprintf(sql, "UPDATE pgtest SET v1 = %d, v2 = %d, v3 = %d, v4 = %d, v5 = %d, total = %d, average = %d WHERE id = %d;",
                    v[0], v[1], v[2], v[3], v[4], sum, sum / 5, target);
            doQry(sql, conn);
        }

        doQry("COMMIT;", conn);
    }
}

/*
 * Drop the pgtest table and its id sequence; exits via doQry() on error.
 *
 * Both statements travel in one PQexec() call, so the server runs them as a
 * single implicit transaction: if either fails, both are rolled back.
 * PostgreSQL drops a serial column's sequence along with its table, so a
 * plain "DROP SEQUENCE pgtest_id_seq" would fail after the table drop and
 * undo it. IF EXISTS makes each drop a no-op when the object is already
 * gone. It also lets reset() work on a database where setup() never ran.
 */
void cleanup(PGconn *conn)
{
    char qry[] = "DROP TABLE IF EXISTS pgtest; DROP SEQUENCE IF EXISTS pgtest_id_seq;";

    doQry(qry, conn);
}


/*
 * Rebuild the test data from scratch: drop pgtest, recreate and repopulate
 * it, then close the connection and terminate with status 0. Never returns
 * to the caller.
 */
void reset(PGconn *conn)
{
    cleanup(conn);   // remove whatever is there
    setup(conn);     // fresh table with 10000 rows

    // Nothing is left to do once the data is rebuilt
    PQfinish(conn);
    exit(0);
}

/*
 * Return a pseudo-random integer uniformly distributed in [1, max], drawn
 * from rand().
 *
 * A plain "rand() % max" favours small results whenever RAND_MAX + 1 is not
 * a multiple of max. The bias is noticeable where RAND_MAX is only 32767.
 * Draws falling into the incomplete top bucket are therefore rejected and
 * redrawn; fewer than two draws are needed on average.
 *
 * If max exceeds RAND_MAX + 1, a single rand() call cannot cover the range.
 * The plain modulo form is used then, so results never exceed RAND_MAX + 1.
 * max < 1 describes an empty range and yields 1 instead of dividing by zero.
 */
int getRand(int max)
{
    const unsigned long span = (unsigned long) RAND_MAX + 1UL;
    unsigned long limit;
    int r;

    if (max <= 1)
        return 1;

    if ((unsigned long) max > span)
        return 1 + rand() % max;

    // Largest multiple of max that does not exceed span
    limit = span - span % (unsigned long) max;
    do {
        r = rand();
    } while ((unsigned long) r >= limit);

    return 1 + r % max;
}

/*
 * Execute one SQL string on conn, discarding any result rows.
 * If the result status is neither PGRES_COMMAND_OK nor PGRES_TUPLES_OK,
 * print libpq's error message to stderr, close the connection and exit
 * with status 1.
 */
void doQry(char *qry, PGconn *conn)
{
    PGresult *qryRes = PQexec(conn, qry);
    ExecStatusType status = PQresultStatus(qryRes);

    if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK) {
        // Emit the message verbatim: server errors can contain '%' characters,
        // which would be misinterpreted if the text were used as a format string.
        fputs(PQerrorMessage(conn), stderr);
        PQclear(qryRes);
        PQfinish(conn);
        exit(1);
    }

    PQclear(qryRes);
}

/*
 * Execute a query that yields a single value and return that value as text.
 * The value is the first column of the first row.
 *
 * The returned string is a copy held in a buffer owned by this function.
 * It remains valid after the underlying PGresult is freed, but the next
 * doScalar() call overwrites it. Callers must not free it. A SQL NULL comes
 * back as an empty string, following the PQgetvalue() convention.
 *
 * On failure the error is printed to stderr, the connection is closed and
 * the process exits with status 1. Failures are: query error, a result with
 * no rows or columns, or out of memory.
 */
char *doScalar(char *qry, PGconn *conn)
{
    static char *val = NULL;    // reused across calls; grown as needed
    PGresult *qryRes;
    const char *cell;
    char *grown;
    size_t len;

    qryRes = PQexec(conn, qry);

    if (PQresultStatus(qryRes) != PGRES_TUPLES_OK) {
        // Print verbatim; the message is not a format string
        fputs(PQerrorMessage(conn), stderr);
        PQclear(qryRes);
        PQfinish(conn);
        exit(1);
    }

    // PQgetvalue() returns NULL for an out-of-range cell
    if (PQntuples(qryRes) < 1 || PQnfields(qryRes) < 1) {
        fprintf(stderr, "Query returned no value: %s\n", qry);
        PQclear(qryRes);
        PQfinish(conn);
        exit(1);
    }

    // Copy the value out before PQclear() releases the storage it lives in
    cell = PQgetvalue(qryRes, 0, 0);
    len = strlen(cell);
    grown = realloc(val, len + 1);
    if (grown == NULL) {
        fprintf(stderr, "Out of memory\n");
        PQclear(qryRes);
        PQfinish(conn);
        exit(1);
    }
    val = grown;
    memcpy(val, cell, len + 1);

    PQclear(qryRes);

    return val;
}
