#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <time.h>
#include <limits.h>

/*
 * Stand-ins for the integer typedefs, bitmap word type and helper macros used
 * by the bitmapset routines below, so this file builds on its own.  NULL is
 * provided by the standard headers included above.
 *
 * NOTE(review): uint32, int32 and signedbitmapword are not used by the code
 * in this file.
 */
typedef uint32_t uint32;
typedef int32_t int32;
typedef uint64_t uint64;
typedef int64_t int64;

/*
 * Width of one bitmap word in bits.  bmw_rightmost_one_pos() and
 * bmw_leftmost_one_pos() use 64-bit builtins and the constant 63, so this
 * must stay 64.
 */
#define BITS_PER_BITMAPWORD 64
typedef uint64 bitmapword;	  /* must be an unsigned type */
typedef int64 signedbitmapword; /* must be the matching signed type */

/*
 * Split member number x into the index of the word holding it (WORDNUM) and
 * its bit position inside that word (BITNUM).
 *
 * For a signed x the argument must be non-negative: C division truncates
 * toward zero, so e.g. x = -1 yields word 0 and bit -1.  Because these are
 * macros, the arithmetic happens in x's own type; the *_patched routines pass
 * an unsigned int and depend on the division being unsigned.
 */
#define WORDNUM(x)  ((x) / BITS_PER_BITMAPWORD)
#define BITNUM(x)   ((x) % BITS_PER_BITMAPWORD)

/*
 * Branch-prediction hints.  Without GCC-compatible builtins they reduce to
 * plain truth tests.  NOTE(review): not referenced anywhere in this file.
 */
#ifdef __GNUC__
#define likely(x)	__builtin_expect((x) != 0, 1)
#define unlikely(x) __builtin_expect((x) != 0, 0)
#else
#define likely(x)	((x) != 0)
#define unlikely(x) ((x) != 0)
#endif

/*
 * A set of non-negative integers stored as a bitmap: member k is present when
 * bit BITNUM(k) of words[WORDNUM(k)] is set.  The words live in a trailing
 * flexible array, so allocations must reserve nwords * sizeof(bitmapword)
 * bytes beyond sizeof(Bitmapset).
 *
 * NOTE(review): bms_prev_member() and bms_prev_member_patched() index out of
 * bounds when nwords == 0, so a non-NULL set presumably must have nwords >= 1.
 */
typedef struct Bitmapset
{
	int		 nwords;		/* number of words in array */
	bitmapword  words[];	/* really [nwords] (flexible array member) */
} Bitmapset;

/*
 * bmw_rightmost_one_pos - index (0..63) of the least significant set bit.
 *
 * 'word' must be nonzero: __builtin_ctzll(0) is undefined.  Every caller in
 * this file tests the word against 0 first.  (uint64_t is the same type as
 * this file's uint64 typedef.)
 *
 * Compilers without GCC builtins get a portable loop, matching the non-GNU
 * fallback the file already provides for likely()/unlikely().  That loop
 * returns 64 for a zero word instead of being undefined.
 */
static inline int
bmw_rightmost_one_pos(uint64_t word)
{
#ifdef __GNUC__
	return __builtin_ctzll(word);
#else
	int			pos = 0;

	/* pos < 64 is tested first, so the shift count always stays in range */
	while (pos < 64 && ((word >> pos) & 1) == 0)
		pos++;
	return pos;
#endif
}

/*
 * bmw_leftmost_one_pos - index (0..63) of the most significant set bit.
 *
 * 'word' must be nonzero: __builtin_clzll(0) is undefined.  Every caller in
 * this file tests the word against 0 first.  (uint64_t is the same type as
 * this file's uint64 typedef.)
 *
 * Compilers without GCC builtins get a portable loop, matching the non-GNU
 * fallback the file already provides for likely()/unlikely().  That loop
 * returns -1 for a zero word instead of being undefined.
 */
static inline int
bmw_leftmost_one_pos(uint64_t word)
{
#ifdef __GNUC__
	return 63 - __builtin_clzll(word);
#else
	int			pos = 63;

	/* pos >= 0 is tested first, so the shift count always stays in range */
	while (pos >= 0 && ((word >> pos) & 1) == 0)
		pos--;
	return pos;
#endif
}

/*
 * bms_next_member - the "master" variant timed by main().
 *
 * Returns the smallest member of 'a' that is greater than 'prevbit', or -2 if
 * there is none or 'a' is NULL.  To iterate over a set, start with
 * prevbit = -1 and pass each result back in until a negative value is
 * returned, as main() does.
 *
 * Keep this code unchanged: it is the baseline that
 * bms_next_member_patched() is timed against, and restyling it would skew
 * the comparison.
 *
 * NOTE(review): valid input is -1 <= prevbit < INT_MAX.  Both of the
 * following are undefined behavior:
 *  - prevbit == INT_MAX makes "prevbit++" overflow a signed int.
 *  - prevbit < -1 makes BITNUM(prevbit) negative, so the mask is shifted by
 *    a negative count.
 * The patched variant does the increment in unsigned int to avoid the
 * overflow.
 */
int
bms_next_member(const Bitmapset *a, int prevbit)
{
	int		 nwords;
	bitmapword  mask;

	if (a == NULL)
		return -2;

	nwords = a->nwords;
	/* first candidate member is prevbit + 1 (signed arithmetic) */
	prevbit++;
	/* in the first word scanned, keep only bits at or above BITNUM(prevbit) */
	mask = (~(bitmapword) 0) << BITNUM(prevbit);
	for (int wordnum = WORDNUM(prevbit); wordnum < nwords; wordnum++)
	{
		bitmapword  w = a->words[wordnum];

		/* ignore bits before prevbit */
		w &= mask;

		if (w != 0)
		{
			int		 result;

			/* convert (word index, lowest set bit) back to a member number */
			result = wordnum * BITS_PER_BITMAPWORD;
			result += bmw_rightmost_one_pos(w);
			return result;
		}

		/* in subsequent words, consider all bits */
		mask = (~(bitmapword) 0);
	}
	return -2;
}

/*
 * bms_next_member_patched - the "Patched" variant of bms_next_member() that
 * main() times against the "master" one.
 *
 * Same contract: returns the smallest member of 'a' greater than 'prevbit',
 * or -2 if there is none or 'a' is NULL.  Start an iteration with
 * prevbit = -1.
 *
 * The difference is that the increment happens on currbit, an unsigned int
 * copy of prevbit:
 *  - prevbit == INT_MAX no longer overflows.
 *  - prevbit == -1 becomes UINT_MAX and wraps (well-defined for unsigned) to 0.
 *  - BITNUM(currbit) is always 0..63, so the shift is always defined.
 *  - A prevbit below -1 becomes a value of at least 2^31 (with 32-bit int),
 *    so the scan starts at word 2^25 or later and returns -2 unless the set
 *    is that large.
 *
 * Keep this code unchanged so the timing comparison stays meaningful.
 */
int
bms_next_member_patched(const Bitmapset *a, int prevbit)
{
	unsigned int currbit = prevbit;
	int			nwords;
	bitmapword	mask;

	if (a == NULL)
		return -2;
	nwords = a->nwords;

	/*
	 * currbit is unsigned, so this increment cannot overflow a signed int;
	 * prevbit == -1 wraps around to 0.
	 */
	currbit++;
	mask = (~(bitmapword) 0) << BITNUM(currbit);
	for (int wordnum = WORDNUM(currbit); wordnum < nwords; wordnum++)
	{
		bitmapword	w = a->words[wordnum];

		/* ignore bits before currbit */
		w &= mask;

		if (w != 0)
		{
			int			result;

			/* convert (word index, lowest set bit) back to a member number */
			result = wordnum * BITS_PER_BITMAPWORD;
			result += bmw_rightmost_one_pos(w);
			return result;
		}

		/* in subsequent words, consider all bits */
		mask = (~(bitmapword) 0);
	}
	return -2;
}

/*
 * bms_prev_member - the "master" variant timed by main().
 *
 * Returns the largest member of 'a' that is less than 'prevbit', or -2 if
 * there is none, 'a' is NULL, or prevbit is 0.  prevbit = -1 starts the
 * search from the highest bit the set's words can hold.  main() iterates
 * downward by starting at -1 and feeding each result back in.
 *
 * Keep this code unchanged: it is the baseline that
 * bms_prev_member_patched() is timed against.
 *
 * NOTE(review): nothing here range-checks prevbit or a->nwords:
 *  - prevbit > a->nwords * BITS_PER_BITMAPWORD starts the scan past the end
 *    of a->words[] (out-of-bounds read).
 *  - prevbit < -1 makes BITNUM(prevbit) negative, so ushiftbits exceeds 64
 *    and the shift is undefined; prevbit == INT_MIN also overflows
 *    "prevbit--".
 *  - a->nwords * BITS_PER_BITMAPWORD is int arithmetic and overflows once
 *    nwords reaches 2^25.
 *  - a->nwords == 0 yields prevbit = -1, hence a shift by 64 (undefined), and
 *    then reads words[0].
 */
int
bms_prev_member(const Bitmapset *a, int prevbit)
{
	int			ushiftbits;
	bitmapword	mask;

	/*
	 * If set is NULL or if there are no more bits to the right then we've
	 * nothing to do.
	 */
	if (a == NULL || prevbit == 0)
		return -2;

	/* transform -1 to the highest possible bit we could have set */
	if (prevbit == -1)
		prevbit = a->nwords * BITS_PER_BITMAPWORD - 1;
	else
		prevbit--;

	/* in the first word scanned, keep only bits 0..BITNUM(prevbit) */
	ushiftbits = BITS_PER_BITMAPWORD - (BITNUM(prevbit) + 1);
	mask = (~(bitmapword) 0) >> ushiftbits;
	for (int wordnum = WORDNUM(prevbit); wordnum >= 0; wordnum--)
	{
		bitmapword	w = a->words[wordnum];

		/* mask out bits left of prevbit */
		w &= mask;

		if (w != 0)
		{
			int			result;

			/* convert (word index, highest set bit) back to a member number */
			result = wordnum * BITS_PER_BITMAPWORD;
			result += bmw_leftmost_one_pos(w);
			return result;
		}

		/* in subsequent words, consider all bits */
		mask = (~(bitmapword) 0);
	}
	return -2;
}

/*
 * bms_prev_member_patched - the "Patched" variant of bms_prev_member() that
 * main() times against the "master" one.
 *
 * Same contract: returns the largest member of 'a' less than 'prevbit', or -2
 * if there is none, 'a' is NULL, or prevbit is 0.  The differences:
 *  - Any negative prevbit, not only -1, means "start from the highest bit the
 *    set can hold".
 *  - The bit position is tracked in the unsigned int currbit, so
 *    a->nwords * BITS_PER_BITMAPWORD cannot overflow a signed int.
 *
 * Keep this code unchanged so the timing comparison stays meaningful.
 *
 * NOTE(review): like the master variant, this does not range-check its
 * inputs:
 *  - prevbit > a->nwords * BITS_PER_BITMAPWORD starts the scan past the end
 *    of a->words[].
 *  - a->nwords == 0 makes currbit wrap to UINT_MAX, so the first word read is
 *    a->words[2^26 - 1] (with 32-bit unsigned int), far out of bounds.
 */
int
bms_prev_member_patched(const Bitmapset *a, int prevbit)
{
	unsigned int currbit;
	int			ushiftbits;
	bitmapword	mask;


	/*
	 * If set is NULL or if there are no more bits to the right then we've
	 * nothing to do.
	 */
	if (a == NULL || prevbit == 0)
		return -2;

	/*
	 * Transform a negative prevbit (normally -1) to the highest possible bit
	 * we could have set.  We do this in unsigned math to avoid the risk of
	 * overflowing a signed int.
	 */
	if (prevbit < 0)
		currbit = (unsigned int) a->nwords * BITS_PER_BITMAPWORD - 1;
	else
		currbit = prevbit - 1;

	/* in the first word scanned, keep only bits 0..BITNUM(currbit) */
	ushiftbits = BITS_PER_BITMAPWORD - (BITNUM(currbit) + 1);
	mask = (~(bitmapword) 0) >> ushiftbits;
	for (int wordnum = WORDNUM(currbit); wordnum >= 0; wordnum--)
	{
		bitmapword	w = a->words[wordnum];

		/* mask out bits left of currbit */
		w &= mask;

		if (w != 0)
		{
			int			result;

			/* convert (word index, highest set bit) back to a member number */
			result = wordnum * BITS_PER_BITMAPWORD;
			result += bmw_leftmost_one_pos(w);
			return result;
		}

		/* in subsequent words, consider all bits */
		mask = (~(bitmapword) 0);
	}
	return -2;
}


/*
 * get_time - processor time consumed by this process so far, in seconds.
 *
 * Uses the POSIX per-process CPU-time clock when <time.h> exposes it.  Under
 * a strict ISO C build (e.g. -std=c11 with glibc) it is not declared, so the
 * code falls back to ISO C clock(), which also reports processor time, with a
 * resolution of 1 / CLOCKS_PER_SEC.  Only the difference between two calls is
 * meaningful.
 *
 * Exits the program if the clock cannot be read.  The previous version ignored
 * the clock_gettime() result and would then read an uninitialized timespec and
 * print meaningless timings.
 */
double
get_time(void)
{
#if defined(CLOCK_PROCESS_CPUTIME_ID)
	struct timespec ts;

	if (clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &ts) != 0)
	{
		perror("clock_gettime");
		exit(EXIT_FAILURE);
	}
	return (double) ts.tv_sec + (double) ts.tv_nsec * 1e-9;
#else
	clock_t		ticks = clock();

	if (ticks == (clock_t) -1)
	{
		fputs("get_time: processor time is unavailable\n", stderr);
		exit(EXIT_FAILURE);
	}
	return (double) ticks / CLOCKS_PER_SEC;
#endif
}

/*
 * The set being benchmarked; main() allocates, populates and frees it.
 * NOTE(review): presumably kept at file scope (external linkage) rather than
 * as a local in main() so the optimizer cannot assume its contents across the
 * opaque get_time()/printf() calls and fold the benchmark loops away.  Confirm
 * before moving it.
 */
Bitmapset *bms;


/*
 * Time "master" and "Patched" variants of bms_next_member() and
 * bms_prev_member() by fully iterating the same small set many times, and
 * print the CPU seconds each variant took.
 *
 * The loops call the functions directly, not through a shared helper taking a
 * function pointer, so each call site can be optimized exactly as a real
 * caller's would be.
 */
int
main(void)
{
	/*
	 * Number of words in the benchmark set.  A single word (64 bits) is the
	 * smallest non-empty set; raise this to benchmark scans that must cross
	 * word boundaries.
	 */
	int			words_to_alloc = 1;
	int			iterations = 100000000;

	/*
	 * Total members visited across all runs.  Printing it at the end makes
	 * the loop results observable, so the optimizer cannot discard the loops.
	 */
	int64		count = 0;
	double		start,
				end;

	/* calloc() zero-fills, so this starts as an empty set of the given size */
	bms = calloc(1, sizeof(Bitmapset) + words_to_alloc * sizeof(bitmapword));
	if (bms == NULL)
	{
		perror("calloc");
		return EXIT_FAILURE;
	}
	bms->nwords = words_to_alloc;

	/*
	 * Put members only in the last word: 0xaf4 sets bits 2, 4-7, 9 and 11.
	 * Every complete iteration therefore visits 7 members, and the forward
	 * scans must skip any empty words before the last one.
	 */
	bms->words[words_to_alloc - 1] = 0xaf4;

	printf("Benchmarking %d bms_next_member iterations...\n", iterations);

	/* master */
	start = get_time();
	for (int i = 0; i < iterations; i++)
	{
		int			j = -1;

		while ((j = bms_next_member(bms, j)) >= 0)
			count++;
	}
	end = get_time();
	printf("master: %.5f seconds\n", end - start);

	/* patched */
	start = get_time();
	for (int i = 0; i < iterations; i++)
	{
		int			j = -1;

		while ((j = bms_next_member_patched(bms, j)) >= 0)
			count++;
	}
	end = get_time();
	printf("Patched: %.5f seconds\n", end - start);

	printf("\nBenchmarking %d bms_prev_member iterations...\n", iterations);

	/* master */
	start = get_time();
	for (int i = 0; i < iterations; i++)
	{
		int			j = -1;

		while ((j = bms_prev_member(bms, j)) >= 0)
			count++;
	}
	end = get_time();
	printf("master: %.5f seconds\n", end - start);

	/* patched */
	start = get_time();
	for (int i = 0; i < iterations; i++)
	{
		int			j = -1;

		while ((j = bms_prev_member_patched(bms, j)) >= 0)
			count++;
	}
	end = get_time();
	printf("Patched: %.5f seconds\n", end - start);

	/*
	 * int64 is "long long" rather than "long" on some platforms (e.g.
	 * Windows, macOS), so "%ld" was a format mismatch.  Cast to a type whose
	 * conversion specifier is fixed.
	 */
	printf("%lld\n", (long long) count);

	free(bms);
	return 0;
}