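/*
 * Broadcast/ACK message-passing benchmark for the VPThread (VMS) runtime.
 * One worker per core passes a "broadcaster" role around using LossyCom
 * best-effort messages; see the usage text below for command-line options.
 */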
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <pthread.h>
#include <unistd.h>
#include "VMS_Implementations/Vthread_impl/VPThread.h"
#include "C_Libraries/Queue_impl/PrivateQueue.h"
#include "C_Libraries/DynArray/DynArray.h"
#include "C_Libraries/BestEffortMessaging/LossyCom.h"

#include <linux/perf_event.h>
#include <linux/prctl.h>
#include <sys/syscall.h>

#undef DEBUG
//#define DEBUG

//#define MEASURE_PERF

#if !defined(unix) && !defined(__unix__)
#ifdef __MACH__
#define unix		1
#define __unix__	1
#endif	/* __MACH__ */
#endif	/* unix */

/* find the appropriate way to define explicitly sized types */
/* for C99 or GNU libc (also mach's libc) we can use stdint.h */
#if (__STDC_VERSION__ >= 199900) || defined(__GLIBC__) || defined(__MACH__)
#include <stdint.h>
#elif defined(unix) || defined(__unix__)	/* some UNIX systems have them in sys/types.h */
#include <sys/types.h>
#elif defined(__WIN32__) || defined(WIN32)	/* the nameless one */
typedef unsigned __int8 uint8_t;
typedef unsigned __int32 uint32_t;
#endif	/* sized type detection */

/* provide a millisecond-resolution timer for each system */
#if defined(unix) || defined(__unix__)
#include <time.h>
#include <sys/time.h>
/*
 * Milliseconds elapsed since the first call. The first call records the
 * reference time and returns 0.
 *
 * The elapsed time is computed in whole microseconds before dividing by 1000.
 * Dividing a negative tv_usec difference on its own truncates toward zero and
 * could report 1 ms too much across a second boundary.
 *
 * Not thread-safe: both timestamps live in static storage.
 */
unsigned long get_msec(void) {
	static struct timeval now, start;
	long long elapsed_us;

	gettimeofday(&now, 0);
	if(start.tv_sec == 0) {
		start = now;
		return 0;
	}
	elapsed_us = (long long)(now.tv_sec - start.tv_sec) * 1000000LL
	           + (long long)(now.tv_usec - start.tv_usec);
	return (unsigned long)(elapsed_us / 1000);
}
#elif defined(__WIN32__) || defined(WIN32)
#include <windows.h>
/* Win32 timer: returns GetTickCount()'s millisecond tick value unchanged.
   Unlike the POSIX version, it is not rebased to the first call. */
unsigned long get_msec(void) {
	unsigned long ticks = GetTickCount();
	return ticks;
}
#else
//#error "I don't know how to measure time on your platform"
#endif

//======================== Defines =========================
/* One hardware-counter sample. Only `cycles` is read anywhere in this file. */
struct perfData {
    uint64 cycles;
    uint64 instructions;
};
typedef struct perfData measurement_t;

/* Help text: printed to stdout for -h and to stderr on argument errors. */
const char *usage =
	"Usage: msg_passing_test [options]\n"
	"  Starts threads equal to the number of cores and sends\n"
	"  messages to random receivers\n"
	"\n"
	"Options:\n"
	"  -n <num>   This specifies the number of sends done by each thread.\n"
	"  -h         this help screen\n"
	"\n";

/***************************
 * Barrier Implementation
 ***************************/

/* State of a reusable counting barrier built on a VPThread mutex/cond pair.
   The counter is reset when the barrier releases, so it can be reused. */
typedef struct barrier_t
{
    int counter;                    /* participants that have arrived so far */
    int nthreads;                   /* participants needed to release the barrier */
    int32 mutex;                    /* VPThread mutex guarding this struct */
    int32 cond;                     /* VPThread condition, created on `mutex` */
    measurement_t endBarrierCycles; /* MEASURE_PERF: cycle count taken by the last arrival */
} barrier;

/*
 * Prepares `barr` for `nthreads` participants. It zeroes the arrival counter
 * and creates the VPThread mutex and a condition variable bound to that mutex.
 *
 * Deliberately not `inline`. Under C99/C11 a plain `inline` definition with
 * external linkage provides no external definition. Any call the compiler
 * chose not to inline (e.g. at -O0) would then fail to link.
 */
void barrier_init(barrier *barr, int nthreads, VirtProcr *animatingPr)
 {
   barr->counter  = 0;
   barr->nthreads = nthreads;
   barr->mutex    = VPThread__make_mutex(animatingPr);
   barr->cond     = VPThread__make_cond(barr->mutex, animatingPr);
 }

/* fd of the process-wide cycle counter (opened in main() under MEASURE_PERF). */
int cycles_counter_main_fd;

/*
 * Blocks until `barr->nthreads` participants have called it.
 *
 * The last arrival resets the counter (so the barrier can be reused) and
 * signals the condition variable once per participant. Under MEASURE_PERF it
 * also stores the process-wide cycle count in barr->endBarrierCycles. On a
 * failed read it reports the error and stores 0, matching
 * saveCyclesAndInstrs().
 *
 * Deliberately not `inline`. A C99/C11 plain-`inline` external definition is
 * not emitted as a callable symbol, so non-inlined calls would fail to link.
 */
void barrier_wait(barrier *barr, VirtProcr *animatingPr)
 { int i;

   VPThread__mutex_lock(barr->mutex, animatingPr);
   barr->counter++;
   if(barr->counter == barr->nthreads)
    {
#ifdef MEASURE_PERF
      if(read(cycles_counter_main_fd, &(barr->endBarrierCycles.cycles),
              sizeof(barr->endBarrierCycles.cycles)) < 0)
       {
         perror("Error reading cycles counter");
         barr->endBarrierCycles.cycles = 0;
       }
#endif
      barr->counter = 0;
      /* NOTE(review): signals nthreads times although at most nthreads-1
         participants are waiting; presumably extra signals are harmless in
         VPThread — confirm. */
      for(i=0; i < barr->nthreads; i++)
         VPThread__cond_signal(barr->cond, animatingPr);
    }
   else
    {
      /* NOTE(review): no predicate re-check loop, so this relies on VPThread
         condition variables never waking spuriously — confirm. */
      VPThread__cond_wait(barr->cond, animatingPr);
    }
   VPThread__mutex_unlock(barr->mutex, animatingPr);
 }


/**************************
 * Worker Parameters
 **************************/
/* Per-worker state. It is passed to worker_TLF() and registered as the data
   argument of that worker's msgHandler(). */
typedef struct
{
   struct barrier_t* barrier;                /* end-of-run barrier shared with benchmark() */
   uint64_t  totalWorkCycles;                /* MEASURE_PERF statistics ... */
   uint64_t  totalBadCycles;
   uint64_t  totalSyncCycles;
   uint64_t  totalBadSyncCycles;
   uint64     numGoodSyncs;
   uint64     numGoodTasks;                  /* ... end of statistics */
   uint64_t   coreID;                        /* core the worker is pinned to; passed as its endpoint ID */
   lossyCom__endpoint_t* localEndpoint;      /* this worker's endpoint (a local in worker_TLF) */
   lossyCom__exchange_t* centralMsgExchange; /* exchange shared by all workers */
   unsigned int receivedACKs;                /* ACKs counted for the current broadcast */
   unsigned int broadcasterStatus;           /* RECEIVING_BROADCAST / BROADCASTING / RECEIVING_ACK */
   unsigned int terminate;                   /* set by msgHandler() on TERMINATE */
} WorkerParams;

/* Where benchmark() reports whole-run cycle counts back to main() (MEASURE_PERF only). */
typedef struct
{
   measurement_t *startExeCycles;   /* read before the workers are created */
   measurement_t *endExeCycles;     /* taken from the final barrier */
} BenchParams;

/* A (receiver, message body) pair.
   NOTE(review): not referenced anywhere else in this file. */
typedef struct
{
   lossyCom__endpointID_t receiverID;
   lossyCom__msgBody_t    msg;
} savedMsg_t;

//======================== Globals =========================
/* Program/data-set identification. NOTE(review): the reserved-looking names are
   presumably required by the VMS runtime — confirm before renaming. */
char __ProgrammName[] = "overhead_test";
char __DataSet[255];

/* Per-worker broadcast budget (-n). The first worker to use it up broadcasts
   TERMINATE. */
int num_msg_to_send;
size_t chunk_size = 0;

/* perf_event fds, one per core, plus the attributes used to open them. */
int cycles_counter_fd[NUM_CORES];
struct perf_event_attr* hw_event;

/* Per-worker parameter blocks, indexed by core ID. */
WorkerParams *workerParamsArray;

/* Shared seeds for randomNumber(); set in benchmark(). */
uint32_t seed1, seed2;

//======================== App Code =========================
/*
 * Workload
 */

/*
 * Reads the cycle counter of core `core` (fd cycles_counter_fd[core]) into the
 * lvalue `cycles`. On a failed read it reports the error with perror() and
 * stores 0. Despite the name, no instruction count is read.
 * The do { } while (0) wrapper makes the macro a single statement.
 */
#define saveCyclesAndInstrs(core,cycles) do {                        \
   int fdOfCore_ = cycles_counter_fd[core];                          \
   if (read(fdOfCore_, &(cycles), sizeof(cycles)) < 0) {             \
      perror("Error reading cycles counter");                       \
      cycles = 0;                                                    \
   }                                                                 \
} while (0)

/*
 * PRNG used to pick the next broadcaster; it has no definition in this file.
 * It is called with the shared seed1/seed2 globals.
 * NOTE(review): presumably advances *seed1/*seed2 on every call — confirm.
 *
 * Declared without `inline`. C11 6.7.4p7 requires an inline-declared function
 * with external linkage to be defined in the same translation unit. A plain
 * prototype is valid whether the definition lives here or in another file.
 */
uint32_t
randomNumber(uint32_t* seed1, uint32_t* seed2);

/*
 * Control message bodies, derived from BROADCAST_ID (defined outside this
 * file). msgHandler() interprets any other body as "you are the next
 * broadcaster" (the chosen endpoint is sent its own ID). Endpoint IDs
 * therefore must not collide with these values.
 * Fully parenthesized so they expand correctly inside any expression.
 */
#define BROADCAST     (BROADCAST_ID)
#define BROADCAST_ACK ((BROADCAST_ID) - 1)
#define TERMINATE     ((BROADCAST_ID) - 2)

/* Values of WorkerParams.broadcasterStatus. */
#define RECEIVING_BROADCAST 0
#define BROADCASTING        1
#define RECEIVING_ACK       2

/*
 * Message Handler Function
 */
/*
 * Message handler registered on every worker's LossyCom endpoint.
 * NOTE(review): presumably invoked from lossyCom__receiveMsg().
 *
 * senderID: endpoint that sent `msg`.
 * msg:      message body.
 * data:     the receiving worker's WorkerParams.
 *
 * Protocol:
 *   BROADCAST     -> reply to the sender with BROADCAST_ACK.
 *   BROADCAST_ACK -> only while this worker is waiting for ACKs: count it.
 *                    Once NUM_CORES-2 have arrived, pick a random other
 *                    endpoint and send it its own ID, handing over the
 *                    broadcaster role. ACKs in any other state are stale and
 *                    are ignored.
 *   TERMINATE     -> set threadData->terminate.
 *   anything else -> this worker becomes the next broadcaster.
 */
void msgHandler(lossyCom__endpointID_t senderID, lossyCom__msgBody_t msg, void* data)
{
    WorkerParams* threadData = (WorkerParams*)data;
    lossyCom__endpoint_t* comEndpoint = threadData->localEndpoint;
    lossyCom__endpointID_t receiverID;

    if(msg == BROADCAST) //answer broadcast message
    {
        lossyCom__sendMsg(comEndpoint, senderID, BROADCAST_ACK);
        return;
    }
    if(msg == (BROADCAST_ACK))
    {
        /* An ACK must never fall through to the "next broadcaster" case
           below. Otherwise an ACK arriving after the hand-off (e.g. one
           beyond the NUM_CORES-2 threshold) would make this worker broadcast
           concurrently with the chosen successor. */
        if(threadData->broadcasterStatus != RECEIVING_ACK)
            return;

        threadData->receivedACKs++;
        /* NOTE(review): the threshold is NUM_CORES-2, presumably to tolerate
           one lost ACK — confirm. `>=` (instead of `==`) lets a threshold of 0
           (NUM_CORES == 2) hand off instead of waiting forever. The status
           change below ensures the hand-off happens only once. */
        if(threadData->receivedACKs >= NUM_CORES-2) //choose next broadcaster
        {
            do{
                receiverID = randomNumber(&seed1, &seed2) % NUM_CORES;
            }while(receiverID == comEndpoint->endpointID);

            //send the receiverID to the receiver to notify it that it is next
            lossyCom__sendMsg(comEndpoint, receiverID, receiverID);
            threadData->broadcasterStatus = RECEIVING_BROADCAST;
        }
        return;
    }
    if(msg == TERMINATE) //termination message
    {
        printf("endpoint %d received termination request\n", comEndpoint->endpointID);
        threadData->terminate = TRUE;
        return;
    }
    //any other body: I'm the next broadcaster!
    threadData->broadcasterStatus = BROADCASTING;
}

/* Number of BROADCAST rounds started across all workers. Only the current
   broadcaster increments it. */
unsigned int global_broadcast_counter;

/*
 * Top-level function of each worker VirtProcr (one per core).
 *
 * Sets up this worker's LossyCom endpoint, with msgHandler() as the handler
 * and `params` as its data. Then loops:
 *   - When msgHandler() has made this worker the broadcaster, it broadcasts
 *     BROADCAST. After num_msg_to_send broadcasts it broadcasts TERMINATE
 *     instead and stops.
 *   - Otherwise it calls lossyCom__receiveMsg() (presumably dispatching to
 *     msgHandler()).
 * The loop also ends when msgHandler() sets params->terminate.
 *
 * Afterwards the worker waits at the shared barrier, stores its statistics
 * (non-zero only with MEASURE_PERF) in params, and dissipates.
 *
 * Returns: never returns in practice. The final return only keeps the compiler
 * from discarding the workspace values.
 */
double
worker_TLF(void* _params, VirtProcr* animatingPr)
 {
   WorkerParams* params = (WorkerParams*)_params;
   unsigned int msgCounter;
   unsigned int totalWorkCycles = 0, totalBadCycles = 0;
   unsigned int totalSyncCycles = 0, totalBadSyncCycles = 0;
   unsigned int workspace1=0, numGoodSyncs = 0, numGoodTasks = 0;
   double workspace2=0.0;

   //core 0 always starts as the broadcaster
   params->broadcasterStatus = params->coreID==0?BROADCASTING:RECEIVING_BROADCAST;

#ifdef MEASURE_PERF
   /* benchmark() creates this worker with affinity to core coreID, so that
      core's cycle counter is the one to sample. */
   int cpuid = (int)params->coreID;
   measurement_t startWorkload, endWorkload;
   uint64 numCycles;

   saveCyclesAndInstrs(cpuid, startWorkload.cycles);
#endif

   //initialize endpoint for communication; msgHandler() receives params as data
   lossyCom__endpoint_t comEndpoint;
   params->localEndpoint = &comEndpoint;
   lossyCom__initialize_endpoint(&comEndpoint,
                                 params->centralMsgExchange,
                                 params->coreID,
                                 msgHandler,
                                 params);

   msgCounter = 0;
   while(msgCounter <= num_msg_to_send)
    {
      if(params->broadcasterStatus == BROADCASTING)
       {
         if(msgCounter == num_msg_to_send) //budget used up: stop everyone
          {
            lossyCom__broadcastMsg(&comEndpoint, TERMINATE);
            break;
          }
         params->receivedACKs = 0;
         lossyCom__broadcastMsg(&comEndpoint, BROADCAST);
         global_broadcast_counter++;
         if(global_broadcast_counter % 1000 == 0)
            printf("broadcast count: %u\n", global_broadcast_counter);
         params->broadcasterStatus = RECEIVING_ACK; //now wait for the ACKs
         msgCounter++;
       }

      //check if the benchmark should terminate
      if(params->terminate)
         break;

      //receive msg
      lossyCom__receiveMsg(&comEndpoint);
    }

#ifdef MEASURE_PERF
   saveCyclesAndInstrs(cpuid, endWorkload.cycles);
   numCycles = endWorkload.cycles - startWorkload.cycles;
   //sanity check (400K is about 20K iters)
   if( numCycles < 400000 ) {totalWorkCycles += numCycles; numGoodTasks++;}
   else                     {totalBadCycles  += numCycles; }
#endif

   //wait for all threads to finish
   barrier_wait(params->barrier, animatingPr);

   params->totalWorkCycles    = totalWorkCycles;
   params->totalBadCycles     = totalBadCycles;
   params->numGoodTasks       = numGoodTasks;
   params->totalSyncCycles    = totalSyncCycles;
   params->totalBadSyncCycles = totalBadSyncCycles;
   params->numGoodSyncs       = numGoodSyncs;

   //Shutdown worker
   VPThread__dissipate_thread(animatingPr);

   //below return never reached --> there for gcc
   return (workspace1 + workspace2);  //to prevent gcc from optimizing work out
 }


/* this is run after the VMS is set up*/
/*
 * Seed procr body; it runs once the VMS runtime is up.
 *
 * Steps:
 *   1. Create the central LossyCom exchange and fill workerParamsArray.
 *   2. Seed the shared PRNG.
 *   3. Start one worker per core, each pinned to its core.
 *   4. Wait with the workers on a barrier of NUM_CORES+1 participants.
 *   5. Print the total broadcast count.
 * Under MEASURE_PERF it also records whole-run start/end cycle counts in the
 * BenchParams passed as _params.
 */
void benchmark(void *_params, VirtProcr *animatingPr)
 {
   int i;
   /* static: workers still read barr (its mutex) while leaving barrier_wait(),
      possibly after this procr has passed the barrier and dissipated. Static
      storage keeps it valid regardless of what happens to this procr's stack. */
   static struct barrier_t barr;
   BenchParams *params = (BenchParams *)_params;
   (void)params; //only used under MEASURE_PERF

   //all workers plus this procr meet at the final barrier
   barrier_init(&barr, NUM_CORES+1, animatingPr);

   //Init central communication exchange
   lossyCom__exchange_t* centralMsgExchange = lossyCom__initialize(NUM_CORES);

   //prepare input
   for(i=0; i<NUM_CORES; i++)
    {
      workerParamsArray[i].barrier            = &barr;
      workerParamsArray[i].coreID             = i;
      workerParamsArray[i].centralMsgExchange = centralMsgExchange;
      workerParamsArray[i].terminate          = FALSE;
    }
   global_broadcast_counter = 0;

   // init random number generator used to choose the next broadcaster
   seed1 = rand()%1000;
   seed2 = rand()%1000;

#ifdef MEASURE_PERF
   //save cycles before execution of threads, to get total exe cycles
   ssize_t nread = read(cycles_counter_main_fd, &(params->startExeCycles->cycles),
                        sizeof(params->startExeCycles->cycles));
   if(nread<0) perror("Error reading cycles counter");
#endif

   //create (which starts running) all threads
   for(i=NUM_CORES-1; i>=0; i--)
    {
      VPThread__create_thread_with_affinity((VirtProcrFnPtr)worker_TLF,
                                            &(workerParamsArray[i]),
                                            animatingPr,
                                            i);//schedule to core i
    }

   barrier_wait(&barr, animatingPr);

#ifdef MEASURE_PERF
   /* barrier_wait() stores the end cycle count in barr.endBarrierCycles when
      the last participant arrives. It is only valid after the wait returns. */
   params->endExeCycles->cycles = barr.endBarrierCycles.cycles;
#endif

   printf("Total broadcast count: %u\n", global_broadcast_counter);

   VPThread__dissipate_thread(animatingPr);
 }

int main(int argc, char **argv)
 {
   int i;

   //set global static variables, based on cmd-line args
   for(i=1; i<argc; i++)
    {
      if(argv[i][0] == '-' && argv[i][2] == 0)
       {
         switch(argv[i][1])
          {
            case 'n':
               if(!isdigit(argv[++i][0]))
                {
                  fprintf(stderr, "-t must be followed by the number messages to send per core\n");
                  return EXIT_FAILURE;
                }
               num_msg_to_send = atoi(argv[i]);
               if(!num_msg_to_send)
                {
                  fprintf(stderr, "invalid number of messages to send: %d\n", num_msg_to_send);
                  return EXIT_FAILURE;
                }
            break;
            case 'h':
               fputs(usage, stdout);
               return 0;		
            default:
               fprintf(stderr, "unrecognized argument: %s\n", argv[i]);
               fputs(usage, stderr);
               return EXIT_FAILURE;
          }//switch
       }//if arg
      else
       {
		fprintf(stderr, "unrecognized argument: %s\n", argv[i]);
		fputs(usage, stderr);
		return EXIT_FAILURE;
       }
    }//for
   
   
#ifdef MEASURE_PERF
   //setup performance counters
    hw_event = malloc(sizeof(struct perf_event_attr));
    memset(hw_event,0,sizeof(struct perf_event_attr));
    
    hw_event->type = PERF_TYPE_HARDWARE;
    hw_event->size = sizeof(hw_event);
    hw_event->disabled = 0;
    hw_event->freq = 0;
    hw_event->inherit = 1; /* children inherit it   */
    hw_event->pinned = 1; /* says this virt counter must always be on HW */
    hw_event->exclusive = 0; /* only group on PMU     */
    hw_event->exclude_user = 0; /* don't count user      */
    hw_event->exclude_kernel = 1; /* don't count kernel  */
    hw_event->exclude_hv = 1; /* ditto hypervisor      */
    hw_event->exclude_idle = 1; /* don't count when idle */
    hw_event->mmap = 0; /* include mmap data     */
    hw_event->comm = 0; /* include comm data     */

    hw_event->config = PERF_COUNT_HW_CPU_CYCLES; //cycles
    
    int cpuID, retries;

   for( cpuID = 0; cpuID < NUM_CORES; cpuID++ )
    { retries = 0;
      do
       { retries += 1;
         cycles_counter_fd[cpuID] = 
          syscall(__NR_perf_event_open, hw_event,
                  0,//pid_t: 0 is "pid of calling process" 
                  cpuID,//int: cpu, the value returned by "CPUID" instr(?)
                  -1,//int: group_fd, -1 is "leader" or independent
                  0//unsigned long: flags
                 );
       }
      while(cycles_counter_fd[cpuID]<0 && retries < 100);
      if(retries >= 100)
       {
         fprintf(stderr,"On core %d: ",cpuID);
         perror("Failed to open cycles counter");
       }
    }

   //Set up counter to accumulate total cycles to process, across all CPUs

   retries = 0;
   do
    { retries += 1;
      cycles_counter_main_fd = 
       syscall(__NR_perf_event_open, hw_event,
               0,//pid_t: 0 is "pid of calling process" 
               -1,//int: cpu, -1 means accumulate from all cores
               -1,//int: group_fd, -1 is "leader" == independent
               0//unsigned long: flags
              );
    }
   while(cycles_counter_main_fd<0 && retries < 100);
   if(retries >= 100)
    {
      fprintf(stderr,"in main ");
      perror("Failed to open cycles counter");
    }
#endif
   
   measurement_t startExeCycles, endExeCycles;
   BenchParams *benchParams;
   
   benchParams = malloc(sizeof(BenchParams)); 
   
   benchParams->startExeCycles = &startExeCycles;
   benchParams->endExeCycles   = &endExeCycles;
   
   workerParamsArray =  (WorkerParams *)malloc( (NUM_CORES) * sizeof(WorkerParams) );
   if(workerParamsArray == NULL ) printf("error mallocing worker params array\n");
   
 
   //This is the transition to the VMS runtime
   VPThread__create_seed_procr_and_do_work( &benchmark, benchParams );
   
#ifdef MEASURE_PERF
   uint64_t totalWorkCyclesAcrossCores = 0, totalBadCyclesAcrossCores = 0;
   uint64_t totalSyncCyclesAcrossCores = 0, totalBadSyncCyclesAcrossCores = 0;
   for(i=0; i<num_threads; i++){ 
       printf("WorkCycles: %lu\n",workerParamsArray[i].totalWorkCycles);
//       printf("Num Good Tasks: %lu\n",workerParamsArray[i].numGoodTasks);
//       printf("SyncCycles: %lu\n",workerParamsArray[i].totalSyncCycles);
//       printf("Num Good Syncs: %lu\n",workerParamsArray[i].numGoodSyncs);
       totalWorkCyclesAcrossCores += workerParamsArray[i].totalWorkCycles;
       totalBadCyclesAcrossCores  += workerParamsArray[i].totalBadCycles;
       totalSyncCyclesAcrossCores += workerParamsArray[i].totalSyncCycles;
       totalBadSyncCyclesAcrossCores  += workerParamsArray[i].totalBadSyncCycles;
    }

   uint64_t totalExeCycles = endExeCycles.cycles - startExeCycles.cycles;
   totalExeCycles -= totalBadCyclesAcrossCores;
   uint64 totalOverhead = totalExeCycles - totalWorkCyclesAcrossCores;
   int32  numSyncs = outer_iters * num_threads * 2;
   printf("Total Execution Cycles: %lu\n", totalExeCycles);
   printf("Sum across threads of work cycles: %lu\n", totalWorkCyclesAcrossCores);
   printf("Sum across threads of bad work cycles: %lu\n", totalBadCyclesAcrossCores);
//   printf("Sum across threads of Bad Sync cycles: %lu\n", totalBadSyncCyclesAcrossCores);
   printf("Overhead per sync: %f\n", (double)totalOverhead / (double)numSyncs );
   printf("ExeCycles/WorkCycles Ratio %f\n", 
          (double)totalExeCycles / (double)totalWorkCyclesAcrossCores);
#else
   printf("#No measurement done!\n");
#endif
   return 0;
 }
