/*

	This program tests the ability of k-LinReg to find the global optimum over 100 trials. 
	With noiseless data, the optimum is 0, and global optimality can be easily checked by setting 'MSE < 1e-6' as a stopping criterion.

	The output of this program looks like
	
	Expe = 1 / 100...  r=1, t=0.3614
	Expe = 2 / 100...  r=1, t=0.2771
	...
	Expe = 100 / 100...  r=1, t=0.2320
	************************
	 N = 1000, n = 2, p = 50
	 -> tau = 324 
	 -> time = 0.0723 +- 0.0985 (tmax = 0.4954)
	 -> r = 1 +- 1 ( rmax = 3 )
	************************
	
	where the results show the number of restarts 'r' and computing time required on average to reach the global optimum.

*/


#include <fenv.h>
#include <math.h>
#include <pthread.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <time.h>
#include <unistd.h>

// Value used as "plus infinity" for costs and error minima (HUGE_VAL from <math.h>)
#define INF HUGE_VAL
// Size of the buffer used by read_data() to read the first line of a data file
#define MAXFIRSTLINELENGTH 80


// Global stop flag: klinreg_global() sets it to 1 once the cost goal is reached;
// the worker threads and the restart loops test it to terminate early, and
// main() resets it to 0 after each experiment.
// Ctrl+C does not go through this flag: ctrlc_handler() terminates the process directly.
static int STOP = 0;

/*
	Handler of ctrl+C signal

	Prints a notice on standard output and terminates the process with
	exit status 1.

	Only async-signal-safe calls (write, _exit) are used here: printf() and
	exit() may deadlock or corrupt stdio state when the signal interrupts a
	thread that is itself inside a stdio call. As a consequence, output that
	is still buffered by stdio is not flushed.
*/
void ctrlc_handler(int sig) {
	static const char msg[] = " ***  Stopped by user ***\n";
	ssize_t written;

	(void)sig;	// imposed by the handler signature, not used

	written = write(STDOUT_FILENO, msg, sizeof(msg) - 1);
	(void)written;	// nothing useful can be done if the write fails

	_exit(1);
}

/* Read data from file (and allocate memory)
	returns number of data points

	File format: the first line holds the number of points n and the
	dimension d; it is followed by n records, each made of d regressor
	values and one target value.

	On success, *pd = d, *X is an n x d array stored row by row
	(X[i*d + j]) and *Y an array of n targets. The caller must free() both.

	If the file cannot be opened, the first line is invalid (unreadable,
	n <= 0 or d <= 0) or memory cannot be allocated, 0 is returned and
	*X and *Y are set to NULL.

	If a record cannot be read (invalid token or premature end of file),
	the number of complete records read so far is returned; *X and *Y are
	still allocated for n records (unread entries are zero) and must be
	freed by the caller.

	call : n = read_data("data.dat", &d, &X, &Y);
*/
int read_data(const char *filename, int *pd, double **X, double **Y) {
	int n = 0, d = 0;
	int i, j;
	FILE *fd;
	char buf[MAXFIRSTLINELENGTH];
	double *Xtmp, *Ytmp;

	*X = NULL;
	*Y = NULL;

	if ((fd = fopen(filename, "r")) == NULL) {
		printf("Wrong filename, abording.\n");
		return 0;
	}

	// Header line: both sizes must be present and strictly positive
	if (fgets(buf, sizeof(buf), fd) == NULL
	    || sscanf(buf, "%d%d", &n, &d) != 2
	    || n <= 0 || d <= 0) {
		fclose(fd);
		printf("Error in data file, abording.\n");
		return 0;
	}
	*pd = d;

	// calloc: entries of records that cannot be read stay well defined (zero)
	Xtmp = calloc((size_t)n, (size_t)d * sizeof(*Xtmp));
	Ytmp = calloc((size_t)n, sizeof(*Ytmp));
	if (Xtmp == NULL || Ytmp == NULL) {
		free(Xtmp);
		free(Ytmp);
		fclose(fd);
		printf("Not enough memory to load %d points in dimension %d.\n", n, d);
		return 0;
	}
	*X = Xtmp;
	*Y = Ytmp;

	for (i = 0; i < n; i++) {
		// fscanf returns EOF (not 0) at end of file, hence the test against 1
		for (j = 0; j < d; j++) {
			if (fscanf(fd, "%lf", &Xtmp[(size_t)i * (size_t)d + (size_t)j]) != 1)
				goto bad_record;
		}
		if (fscanf(fd, "%lf", &Ytmp[i]) != 1)
			goto bad_record;
	}
	fclose(fd);
	return n;

bad_record:
	printf("Error found in file %s\n", filename);
	fclose(fd);
	return i;	// number of complete records
}

/*
	Draw a random initialization for parameter vectors

	Appends one new initial point to the list w0: the n*p entries of slot
	number *n_w0, i.e. w0[(*n_w0)*n*p] ... w0[(*n_w0 + 1)*n*p - 1], are drawn
	independently with drand48() and scaled to [-100, 100]; *n_w0 is then
	incremented. The caller must provide room for that slot in w0.
*/
void init_w(const int n, const int p, double *w0, int *n_w0) {
	const double w_range = 100; // w0 is within [-w_range, w_range]
	const size_t dim = (size_t)n * (size_t)p;
	double *slot = w0 + (size_t)(*n_w0) * dim;
	size_t j;

	// Fill the new slot in place: no temporary buffer, hence no allocation that could fail
	for (j = 0; j < dim; j++)
		slot[j] = w_range * (2 * drand48() - 1);

	*n_w0 = *n_w0 + 1;
}

/*
	Define data for working thread

	klinreg_global() fills one instance of this structure before each
	pthread_create(), and klinreg_thread() reads it when it starts.
*/
struct th_data { 
	int run;	// index of the run (restart), printed in verbose mode
	double *X;	// regressors, read as X[i*p + k] (point i, component k)
	double *y;	// targets, y[i] for point i
	int N;		// number of data points
	int n;		// number of modes (parameter vectors)
	int p;		// dimension of each parameter vector

	double *w0; 	// initial vector
	double *w; 	// local minimizer on output
	double *fval;  	// local minimum
	
	// general settings:	
	int itermax;	// maximum number of iterations of the inner loop
	int verbose;	// verbosity level of the thread (0 = silent)
};

// Guards the th_data instance handed to the threads: klinreg_global() locks it
// before filling the structure, klinreg_thread() unlocks it once it has read its parameters.
static pthread_mutex_t thread_data_mutex = PTHREAD_MUTEX_INITIALIZER;

/*
	LAPACK least squares solver based on the SVD (divide and conquer).
	Fortran calling convention: every argument is passed by address.
	RANK and INFO are integer outputs.
*/
extern void dgelsd_(int *m, int *n, int *nrhs, double *a, int *lda,
		double *b, int *ldb, double *s, double *rcond, int *rank,
		double *work, int *lwork, int *iwork, int *info);

/*
	k-linreg working thread

	th_data points to a struct th_data. Starting from data->w0, the thread
	alternates between
	  - assigning each point to the mode with the smallest squared error,
	  - re-estimating each mode by least squares on its points (dgelsd),
	until the parameters stop changing, the cost stops decreasing, itermax
	iterations are done or the global STOP flag is raised.

	On exit, data->w holds the n*p parameters found and *data->fval the mean
	over the N points of the minimum squared error, or INF if a mode ended
	up with p points or less (the run is then dropped).

	The thread unlocks thread_data_mutex as soon as it has read its
	parameters, so that the launcher can reuse the th_data structure.
	NOTE(review): the mutex is locked by the launching thread and unlocked
	here; POSIX leaves this undefined for default mutexes - a semaphore or a
	per-thread th_data copy would be safer.

	Memory allocation failures terminate the whole process with status 1.
*/
void *klinreg_thread(void *th_data) {

	// Recover thread parameters
	struct th_data *data = (struct th_data *)th_data;
	double *X = data->X;
	double *y = data->y;
	double *wt = data->w;
	double *fval = data->fval;
	const int N = data->N;
	const int n = data->n;
	const int p = data->p;
	const int itermax = data->itermax;
	const int verbose = data->verbose;
	const int run = data->run;
	const size_t np = (size_t)n * (size_t)p;

	int i, j, k;
	size_t idx;

	// copy w0 -> wt
	for (idx = 0; idx < np; idx++)
		wt[idx] = data->w0[idx];

	// Release mutex on thread data (end of critical section that launches threads)
	pthread_mutex_unlock(&thread_data_mutex);

	// local variables
	double eij, min_e, sum_e, sum_e_prec, stopping_criterion;
	int iter;
	const double epsilon = 0.0;	// threshold on ||wt+1 - wt||^2
	const double eps_f = 0.00;	// threshold on the decrease of the cost

	double *w = malloc(sizeof(double) * np);		// previous parameter vectors
	int *S = malloc(sizeof(int) * (size_t)N);		// mode assigned to each point
	int *Ni = malloc(sizeof(int) * (size_t)n);		// number of points per mode
	double *SingularValues = malloc(sizeof(double) * (size_t)p);

	/*
		LIWORK >= max(1, 3 * MINMN * NLVL + 11 * MINMN),
		where MINMN = MIN( M,N ) = p and
		NLVL = MAX( 0, INT( LOG_2( MIN( M,N )/(SMLSIZ+1) ) ) + 1 )

		SMLSIZ is returned by ILAENV (usually about 25); the value 2 used
		here gives an upper bound. NLVL must be clamped at 0, otherwise
		the bound falls below 11 * p for p < 3.
	*/
	double nlvl = 1 + log2((double)p / 3.0);
	if (nlvl < 0)
		nlvl = 0;
	int iwork_size = (int)(3 * p * nlvl + 11 * p);
	if (iwork_size < 1)
		iwork_size = 1;
	int *IWORK_ = malloc(sizeof(int) * (size_t)iwork_size);

	double *WORK_ = NULL;	// grown on demand, see below
	int work_size = 0;

	if (w == NULL || S == NULL || Ni == NULL || SingularValues == NULL || IWORK_ == NULL) {
		printf("ERROR allocating workspace\n");
		exit(1);
	}

	// One N x p column-major block of A and one block of N entries of B per mode
	double *A = malloc(sizeof(double) * (size_t)N * np);
	if (A == NULL) {
		printf("ERROR allocating A\n");
		exit(1);
	}
	double *B = malloc(sizeof(double) * (size_t)N * (size_t)n);
	if (B == NULL) {
		printf("ERROR allocating B\n");
		exit(1);
	}

	// inner loop
	sum_e = INF;
	iter = 0;
	do {
		iter++;

		// compute S_i : assignment of point i
		sum_e_prec = sum_e;
		sum_e = 0;
		for (i = 0; i < N; i++) {
			const double *xi = &X[(size_t)i * (size_t)p];

			min_e = INF;
			S[i] = -1;	// stays -1 if no error is below INF (point then ignored)
			for (j = 0; j < n; j++) {
				const double *wj = &wt[(size_t)j * (size_t)p];
				double output = xi[0] * wj[0];

				for (k = 1; k < p; k++)
					output += xi[k] * wj[k];	// jth model output
				eij = y[i] - output;
				eij *= eij;			// squared error
				if (eij < min_e) {
					S[i] = j;
					min_e = eij;		// minimum of error
				}
			}
			sum_e += min_e;
		}
		if (verbose > 1)
			printf("f = %lf\n", sum_e);

		// gather the points of each mode in its block of A and B
		for (j = 0; j < n; j++)
			Ni[j] = 0;

		for (i = 0; i < N; i++) {
			j = S[i];
			if (j < 0)
				continue;

			// A should be column major, with leading dimension N
			double *Aj = &A[(size_t)j * (size_t)p * (size_t)N];
			for (k = 0; k < p; k++)
				Aj[(size_t)k * (size_t)N + (size_t)Ni[j]] = X[(size_t)i * (size_t)p + (size_t)k];
			B[(size_t)j * (size_t)N + (size_t)Ni[j]] = y[i];
			Ni[j]++;
		}

		// update parameters
		stopping_criterion = 0;
		for (j = 0; j < n; j++) {

			// Stop if a mode has p points or less
			if (Ni[j] <= p) {
				stopping_criterion = -1;
				break;
			}

			double *wj = &wt[(size_t)j * (size_t)p];
			double *wj_old = &w[(size_t)j * (size_t)p];
			double *Aj = &A[(size_t)j * (size_t)p * (size_t)N];
			double *Bj = &B[(size_t)j * (size_t)N];

			// save old w
			for (k = 0; k < p; k++)
				wj_old[k] = wj[k];

			/*
			 Solve the least squares problem

			  min ||B - Ax||_2

			A : M x N matrix
			B : M x NRHS matrix

			with LAPACK routine DGELSD
			*/
			int info = 0;
			int NRHS_ = 1;
			int M_ = Ni[j];
			int N_ = p;
			int LDA_ = N;
			int LDB_ = M_;		// M_ > N_ here, so LDB_ >= max(M_, N_)
			int RANK_ = 0;		// output by lapack
			double RCOND_ = -1;	// negative for machine precision
			double LWORK_opt = 0;
			int LWORK_ = -1;	// query optimal LWORK size
			int iwork_query = 0;	// receives the minimum LIWORK if LAPACK reports it

			// query memory requirements
			dgelsd_(&M_, &N_, &NRHS_, Aj, &LDA_, Bj, &LDB_, SingularValues, &RCOND_, &RANK_, &LWORK_opt, &LWORK_, &iwork_query, &info);
			LWORK_ = (int)LWORK_opt;
			if (LWORK_ < 1)
				LWORK_ = 1;

			if (LWORK_ > work_size) {
				double *grown = realloc(WORK_, sizeof(double) * (size_t)LWORK_);
				if (grown == NULL) {
					printf("ERROR allocating WORK\n");
					exit(1);
				}
				WORK_ = grown;
				work_size = LWORK_;
			}
			if (iwork_query > iwork_size) {
				int *grown = realloc(IWORK_, sizeof(int) * (size_t)iwork_query);
				if (grown == NULL) {
					printf("ERROR allocating IWORK\n");
					exit(1);
				}
				IWORK_ = grown;
				iwork_size = iwork_query;
			}

			// Solve (the solution is returned in the first p entries of Bj)
			dgelsd_(&M_, &N_, &NRHS_, Aj, &LDA_, Bj, &LDB_, SingularValues, &RCOND_, &RANK_, WORK_, &LWORK_, IWORK_, &info);

			if (info != 0)
				printf("Warning in dgelsd (least squares solution for X_%d): info = %d\n", j, info);

			// Convergence test
			for (k = 0; k < p; k++) {
				wj[k] = Bj[k];

				if (verbose > 1)
					printf("%lf, ", wj[k]);
				// test convergence with ||wt - w||^2 < eps
				stopping_criterion += (wj[k] - wj_old[k]) * (wj[k] - wj_old[k]);
			}
		}
	} while (stopping_criterion > epsilon && (sum_e_prec - sum_e) > eps_f && iter < itermax && !STOP);

	// compute cost function value at local solution
	if (stopping_criterion < 0)
		*fval = INF;
	else {
		sum_e = 0;
		for (i = 0; i < N; i++) {
			min_e = INF;
			for (j = 0; j < n; j++) {
				eij = y[i];				// target output
				for (k = 0; k < p; k++)			// minus
					eij -= X[(size_t)i * (size_t)p + (size_t)k] * wt[(size_t)j * (size_t)p + (size_t)k];	// jth model output
				eij *= eij;				// squared error
				if (eij < min_e)
					min_e = eij;			// minimum of error
			}
			sum_e += min_e;
		}
		*fval = sum_e / N;
	}

	if (verbose) {
		if (stopping_criterion < 0)
			printf(" %d\t %d  \t (dropped one mode)\t\t %e\n", run, iter, *fval);
		else if (stopping_criterion <= epsilon)
			printf(" %d\t %d  \t (||wt+1 - wt||^2 <= %1.1e )\t %e\n", run, iter, epsilon, *fval);
		else if (iter >= itermax)
			printf(" %d\t %d  \t (max nb of iterations)\t\t %e\n", run, iter, *fval);
		else
			printf(" %d\t %d  \t (cost_t - cost_t+1 < %lf)\t %e\n", run, iter, eps_f, *fval);
	}

	free(w);
	free(A);
	free(B);
	free(WORK_);

	free(SingularValues);
	free(IWORK_);

	free(S);
	free(Ni);
	pthread_exit(NULL);
}

/*
	Parallel k-LinReg algorithm

	returns best cost function value
*/
double klinreg_global(double *X, double *y, double *wb, const int N, const int n, const int p, int *restarts_, int *bestrun_, const double fgoal, const int itermax, const int verbose) {

	unsigned int i,j,k,l;
	int restarts = *restarts_;
	
	time_t cputime,starttime = time(NULL);
	
	if(restarts < 1)
		restarts = 10000;
	
	double fbest = INF;
	int run = 0;
	int t;

	int kfold = 10;

	double *w0 = (double*)malloc(sizeof(double) * n * p * restarts); // keep list of w0
	if( w0 == NULL ) {
		printf("Not enough memory for %d initializations in dimension %d.\n",restarts,n*p);
		exit(1);
	}
	double *w = (double*)malloc(sizeof(double) * n * p * restarts); // list of local minimizers
	if( w == NULL ) {
		printf("Not enough memory for %d minimizers in dimension %d.\n",restarts,n*p);
		exit(1);
	}
	
	int n_w0 = 0;
	
	pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t) * restarts); // threads id
	pthread_t *threads_kfold = (pthread_t*)malloc(sizeof(pthread_t) * kfold); // threads id
	void *status; // for pthread_join
	int rc; // return code of pthread functions
		
	struct th_data thread_data, thread_data_kfold; 
	size_t stacksize = 100000000;
	pthread_attr_t attr;
	pthread_attr_init(&attr);
   	pthread_attr_setstacksize (&attr, stacksize);

	int r,rundone,bestrun;
	double *fval = (double*)malloc(sizeof(double) * restarts); // stores the local minima
	
	// prepare common data for all threads
	thread_data.X = X;
	thread_data.y = y;						
	thread_data.N = N;		
	thread_data.n = n; 
	thread_data.p = p;
	thread_data.itermax = itermax;
	if(verbose>0)
		thread_data.verbose = verbose-1;
	else
		thread_data.verbose = 0;
		 
	
	// Init display
	
	if(verbose > 1) {
		printf("\n run   \t #iterations  (convergence) \t\t cost \n");
		printf("--------------------------------------------------------------\n");
		
	}
	else if(verbose == 1) {
		printf("\n run   \t cost   \n");
		printf("-----------------------------------------------------\n");
		
	}

	// initialize random number generator

	srand48(time(NULL));
	rundone = 0;
	do { 
		// Launch NUMTHREADS threads with different initializations
		do {
		
			// random initialization of the parameters		
			init_w(n, p, w0, &n_w0);
			
			// Enters critical section (and wait for previous thread to finish reading parameters)
			pthread_mutex_lock(&thread_data_mutex); 
	
			// set parameters for the thread
			thread_data.run = run; // #thread
			thread_data.w0 = &w0[n * p * run]; // initial vector
			thread_data.w = &w[n * p * run]; // local minimizer on output 	
			thread_data.fval = &fval[run]; // local minimum for this initialization
		
			// launch thread
			rc = pthread_create(&threads[run++], &attr, klinreg_thread, (void *) &thread_data);	
			if(rc != 0) {
				printf("Cannot create thread: probably due to a lack of memory, try to reduce stacksize.\n");
				exit(1);
			}
		} while((run-rundone) < NUMTHREADS && run < restarts && STOP!=1 ); 
	
		// Handle the local minimum and minimizers found so far
		for(r=rundone;r<run;r++) {

			// Wait for threads to terminate
			rc = pthread_join(threads[r],&status);
							
			//  Look for best minimum so far
			if(fval[r] < fbest) {
				fbest = fval[r];
				bestrun = r;		
				cputime = time(NULL) - starttime;
				
				// Copy best minimizer to wb
				for(k=0;k < n*p ; k++) // copy all wj in one shot
					wb[k] = w[n*p*bestrun + k];
					
					
				// Test TOLmse
				if(fbest <= fgoal) {
					STOP = 1;							
				}

				if(verbose>1) {
					printf("    *** new best cost function value = %e (goal is %1.2e) *** \n",fbest,fgoal);	
				}
				else if(verbose) {
					printf("\r  #%d\t %e (goal is %1.2e) \n",bestrun,fbest,fgoal);
				}							
			}
		}
		rundone = run;  
		
		if(verbose == 1 && run%10 == 0){
			printf(".");
			fflush(stdout);
		}
		
		// Loop until all restarts are done
	} while (run < restarts &&  STOP != 1);
	
	if(verbose)
		printf("\nDone. %d / %d initializations tried. \nBest minimum found after %d restarts in %u seconds. \n",n_w0,*restarts_,bestrun,(unsigned int)cputime);
	
	*restarts_ = n_w0; // number of restarts really performed
	*bestrun_ = bestrun; // index of run that led to the best minimum
	
	// Copy best minimizer to wb
	for(k=0;k < n*p ; k++) // copy all wj in one shot
		wb[k] = w[n*p*bestrun + k];
				
	free(w);free(w0);free(fval);free(threads);
	
	return fbest;
		
}
/*
	Help: print the description of the program and its usage on standard output
*/
void print_help(void) {
	static const char *const help_text[] = {
		" \n k-LinReg algorithm for switched linear regression \n",
		"  -- a parallel implementation by F. Lauer (2011)\n\n",
		" This is a test program that evaluates the number of restarts\n",
		" required to reach the global optimum on noiseless data.\n\n",
		"Usage:\n",
		"\t test_global N n p\n\n",
	};
	const size_t count = sizeof(help_text) / sizeof(help_text[0]);
	size_t line;

	for (line = 0; line < count; line++)
		fputs(help_text[line], stdout);
}

/*
	main function

	Usage: test_global N n p

	Runs 100 experiments. Each one generates a noiseless data set of N
	points for n modes in dimension p with the external ./gendata program
	(written to data.dat), then calls klinreg_global() with at most 99
	restarts and a cost goal of 1e-6.
	Prints, for each experiment, the number of restarts and the time needed
	to reach the goal, then their mean, standard deviation and maximum.
	The program stops at the first experiment that does not reach the goal.

	Returns 1 on invalid arguments, missing data or lack of memory,
	0 otherwise.
*/
int main(int argc, char *argv[]) {

	// config options
	int itermax, restarts;
	double fgoal = 0.0;
	const int dynamic = 0;	// 1 to use the formulas for dynamical systems

	// other variables
	int i;
	double *X, *y, fbest;
	int N, n, p;
	int tau, T;
	double K;
	double proba_good;

	char gendata_cmd[80];
	float tmax = 0, tmean = 0, tstd = 0, *temps;
	int bestrun, rmax = 0, rmean = 0, rstd = 0, *r;
	int expe, Nexpe = 100;
	struct timeval timeofdaySTART, timeofdayEND;
	double elapsed;

	if (argc < 4) {
		print_help();
		exit(0);
	}

	// set options
	N = atoi(argv[1]);
	n = atoi(argv[2]);
	p = atoi(argv[3]);
	if (N < 1 || n < 1 || p < 1) {
		printf("N, n and p must be positive integers.\n");
		print_help();
		exit(1);
	}

	itermax = 500;
	fgoal = 1e-6;

	printf("Starting to optimize %d models with %d points in R^%d\n (searching for mean squared error of %1.2e)\n", n, N, p, fgoal);

	/*
	 Compute the probability of success

		Psuccess = K(1 - exp(-(N-tau) / T) )

	 where
		K = 0.99;
		T = 2.22 * 2^n * p;
		tau = (2^n * p)^(1.4) * 0.2;

	or, for dynamical systems (with -d flag)
		K = 1.02 - 0.023 * n;
		T = 52 * sqrt(2^n * p ) - 220;
		tau = 1.93 * 2^n * p - 37;
	*/

	tau = p;
	for (i = 0; i < n; i++)
		tau *= 2;	// tau = 2^n * p

	if (dynamic) {
		K = 1.02 - 0.023 * n;
		T = (int)round(52 * sqrt((double)tau) - 220);
		tau = (int)round(1.93 * tau - 37);
	}
	else {
		K = 0.99;
		T = (int)round(2.22 * tau);
		tau = (int)round(0.2 * pow((double)tau, 1.4));
	}

	// (N - tau) / T is a real-valued ratio: do not let it truncate to an integer
	if (T > 0)
		proba_good = K * (1 - exp(-(double)(N - tau) / (double)T));
	else
		proba_good = 0;

	if (proba_good <= 0)
		proba_good = 0;

	printf("Estimated probability of success = %e \n (N=%d ; tau = %d,  T = %d )\n", proba_good, N, tau, T);

	temps = malloc(sizeof(float) * (size_t)Nexpe);
	r = malloc(sizeof(int) * (size_t)Nexpe);
	int w_dim = n * p;	// number of entries allocated in w
	double *w = malloc(sizeof(double) * (size_t)w_dim); // parameter vectors
	if (temps == NULL || r == NULL || w == NULL) {
		printf("Not enough memory.\n");
		free(w);
		free(temps);
		free(r);
		return 1;
	}

	// Install handler of SIGINT (ctrl+C) signal
	struct sigaction newhandler, oldhandler;
	memset(&newhandler, 0, sizeof(newhandler));
	newhandler.sa_handler = &ctrlc_handler;
	sigaction(SIGINT, &newhandler, &oldhandler);

	for (expe = 0; expe < Nexpe; expe++) {

		printf("Expe = %d / %d... ", expe + 1, Nexpe);
		fflush(stdout);

		snprintf(gendata_cmd, sizeof(gendata_cmd), "./gendata data.dat %d %d %d 0. > /dev/null", N, n, p);
		if (system(gendata_cmd) != 0)
			printf("Warning: the data generation command reported an error.\n");

		// READ DATA FILE (N and p are replaced by the sizes found in the file)
		X = NULL;
		y = NULL;
		N = read_data("data.dat", &p, &X, &y);
		if (N == 0) {
			printf("No data.\n");
			free(X);
			free(y);
			free(w);
			free(temps);
			free(r);
			return 1;
		}

		// make sure w can hold n vectors of the dimension found in the file
		if (n * p > w_dim) {
			double *grown = realloc(w, sizeof(double) * (size_t)n * (size_t)p);
			if (grown == NULL) {
				printf("Not enough memory.\n");
				free(X);
				free(y);
				free(w);
				free(temps);
				free(r);
				return 1;
			}
			w = grown;
			w_dim = n * p;
		}

		gettimeofday(&timeofdaySTART, NULL);

		restarts = 99;
		fbest = klinreg_global(X, y, w, N, n, p, &restarts, &bestrun, fgoal, itermax, 0);
		STOP = 0;
		free(X);
		free(y);

		gettimeofday(&timeofdayEND, NULL);

		elapsed = (double)(timeofdayEND.tv_sec - timeofdaySTART.tv_sec)
			+ 1e-6 * (double)(timeofdayEND.tv_usec - timeofdaySTART.tv_usec);

		r[expe] = bestrun + 1;
		temps[expe] = (float)elapsed;

		rmean += r[expe];
		tmean += temps[expe];

		if (fbest > fgoal) {
			r[expe] = 100;
			rmax = 100;
			tmax = temps[expe];
			printf("******** FAILED *********\n");
			printf(" N = %d, n = %d, p = %d\n", N, n, p);
			printf(" -> tau = %d\n", tau);
			printf(" -> time > %3.4f\n", tmax);
			printf(" -> r > %d \n\n", rmax);
			printf("************************\n");
			free(w);
			free(temps);
			free(r);
			return 0;
		}
		if (r[expe] > rmax)
			rmax = r[expe];
		if (temps[expe] > tmax)
			tmax = temps[expe];

		printf(" r=%d, t=%3.4f\n", r[expe], temps[expe]);
	}

	// here expe == Nexpe
	rmean /= expe;
	tmean /= (float)expe;

	for (expe = 0; expe < Nexpe; expe++) {
		rstd += (r[expe] - rmean) * (r[expe] - rmean);
		tstd += (temps[expe] - tmean) * (temps[expe] - tmean);
	}
	if (rstd != 0)
		rstd = (int)ceil(sqrt((float)rstd / (float)(expe - 1)));

	tstd /= (float)(expe - 1);
	tstd = sqrt(tstd);

	printf("************************\n");
	printf(" N = %d, n = %d, p = %d\n", N, n, p);
	printf(" -> tau = %d\n", tau);
	printf(" -> time = %3.4f +- %3.4f (tmax = %3.4f)\n", tmean, tstd, tmax);
	printf(" -> r = %d +- %d ( rmax = %d ) \n\n", rmean, rstd, rmax);
	printf("************************\n");

	free(w);
	free(temps);
	free(r);
	return 0;
}
