#include <stdio.h>
#include <math.h>
#include <stddef.h>
#include <stdlib.h>
#include <ctype.h>

	double
chi(nsam,gams,na,freq,nlocs, config, ran_ind, expected)
	int nsam, *gams, na, *freq, *config, *ran_ind, nlocs ;
	double **expected;
{
	/*
	 * Chi-square statistic: the sum over locations and alleles of
	 * (observed - expected)^2 / expected, skipping cells whose expected
	 * count is not positive.
	 *
	 * ran_ind is walked in nlocs consecutive blocks; block loc holds
	 * config[loc] entries, and for each entry j the allele gams[ran_ind[j]]
	 * (expected to lie in 0..na-1) is tallied for that location.
	 * expected[loc][i] is the expected count of allele i at location loc.
	 *
	 * Returns 0.0 when there are fewer than two alleles. nsam and freq are
	 * accepted for interface compatibility but not used. Exits the program
	 * if the per-call tally buffer cannot be allocated.
	 */
	int *freq1, i, j, loc, jstart ;
	double sum = 0., dif;

	(void)nsam;
	(void)freq;

	if( na < 2 ) return( 0.0 );

	/* Without this check a failed allocation would be dereferenced below. */
	if( (freq1 = (int *) malloc( (size_t)na*sizeof(int) )) == NULL ) {
		perror( "chi: malloc" );
		exit( 1 );
	}

	for( loc = jstart = 0; loc < nlocs; jstart += config[loc++] ) {
		for( i = 0; i < na; i++ ) freq1[i] = 0 ;
		for( j = jstart; j < jstart + config[loc]; j++ )
			freq1[gams[ran_ind[j]]]++;
		for( i = 0; i < na; i++ ) {
			/* Cells with no expected count contribute nothing (avoids dividing by 0). */
			if( expected[loc][i] > 0.0 ) {
				dif = freq1[i] - expected[loc][i] ;
				sum += dif*dif/expected[loc][i] ;
			}
		}
	}
	free( freq1 );
	return( sum ) ;
}


	static void *
permchi_alloc(size_t count, size_t size, const char *what)
{
	/*
	 * Allocate count elements of the given size, or print `what` via
	 * perror and exit(1). At least one element is requested so that a
	 * legitimate zero-length request is never mistaken for failure.
	 */
	void *p = malloc( (count > 0 ? count : 1) * size );

	if( p == NULL ) {
		perror( what );
		exit( 1 );
	}
	return( p );
}

	double
permchi(nperms,nsam,nlocs,ni,dij,pnum,pht,pchi)
   int nperms,nlocs,nsam, *ni, *pnum ;
   double **dij,  *pht, *pchi ;
{
	/*
	 * Permutation chi-square test over nlocs locations.
	 *
	 * Samples 0..nsam-1 are grouped into alleles from the distance matrix
	 * dij (countalleles), heterozygosity is computed (hetero), and the
	 * observed chi-square statistic is taken with location loc owning the
	 * next ni[loc] samples in order. The sample order is then shuffled
	 * nperms times, and permutations whose statistic is >= the observed
	 * one are counted.
	 *
	 * Outputs: *pnum = number of alleles, *pht = heterozygosity,
	 * *pchi = observed chi-square.
	 * Returns the fraction of permutations with statistic >= observed, or
	 * NAN when nperms < 1 (no permutations, so no p-value).
	 * Exits the program if any allocation fails.
	 */
	int i, count, countks, numalleles ;
	double ht, hetero(), chio, chir, **expected, chi(), eps = 1.0e-8, calcexp() ;
	int *ran_ind, *gams, *frequ ;
	int countalleles(), scramb();	/* defined later in this file */

	gams = permchi_alloc( (size_t)nsam, sizeof *gams, "malloc error5\n" );
	frequ = permchi_alloc( (size_t)nsam, sizeof *frequ, "malloc error5\n" );
	ran_ind = permchi_alloc( (size_t)nsam, sizeof *ran_ind, "malloc error9\n" );

	for( i = 0; i < nsam; i++ ) ran_ind[i] = i;
	*pnum = numalleles = countalleles(nsam, dij, 0, nsam, frequ, gams);
	ht = hetero(nsam, numalleles, frequ);
	*pht = ht ;

	expected = permchi_alloc( (size_t)nlocs, sizeof *expected, "malloc error2\n" );
	for( i = 0; i < nlocs; i++ )
		expected[i] = permchi_alloc( (size_t)numalleles, sizeof **expected, "malloc error3\n" );

	calcexp( nsam, nlocs, ni, numalleles, frequ, expected );
	*pchi = chio = chi(nsam, gams, numalleles, frequ, nlocs, ni, ran_ind, expected);

	/* Lower the threshold by a relative 1e-8 so permuted statistics equal to
	   the observed one up to rounding still count as ">=". */
	chio -= eps*chio ;

	/* Bounded loop: a negative nperms simply runs zero permutations. */
	countks = 0 ;
	for( count = 0; count < nperms; count++ ) {
		scramb(ran_ind, nsam);
		chir = chi(nsam, gams, numalleles, frequ, nlocs, ni, ran_ind, expected);
		if( chir >= chio ) countks++ ;
	}

	free( ran_ind );
	free( gams );
	free( frequ );
	for( i = 0; i < nlocs; i++ ) free( expected[i] );
	free( expected );

	if( nperms < 1 ) return( NAN );
	return( (double)countks/(double)nperms );
}
	



/*
 * Command-line driver:
 *     permchi y|n n_permutations n_locals n1 n2 ... < distance-matrix
 *
 * argv[1][0] == 'y' echoes the matrix as it is read; argv[2] is the number
 * of permutations; argv[3] is the number of locations; argv[4..] give the
 * sample size of each location. The nsam x nsam distance matrix
 * (nsam = n1 + n2 + ...) is read from stdin by getmatrix. The test is run
 * on all locations together and, when there are more than two locations,
 * on every pair of locations. Exits with status 1 on bad arguments or
 * allocation failure.
 */
int
main( int argc, char * argv[] )
{
	int nperms, nlocs, nsam ;
	double **dij, pval, ht, chi, **tdij ;
	int i, loc, *ni, numalleles, tni[2], loc1, loc2, j, sti, stj, endi, endj, off ;
	int getmatrix();	/* defined later in this file */
	double permchi();

	if( argc < 5 ) {
	   printf("usage: permchi  y/n(print data?)  n_permutations n_locals n1 n2 ... \n");
	   exit(1) ;
	}
	nperms = atoi( argv[2] ) ;
	nlocs = atoi( argv[3] );
	/* argv must supply one sample size per location (otherwise argv[4+i]
	   would be read out of range), and nperms < 1 gives no p-value. */
	if( nperms < 1 || nlocs < 1 || nlocs > argc - 4 ) {
	   fprintf(stderr, "permchi: need n_permutations >= 1, n_locals >= 1 and n_locals sample sizes\n");
	   printf("usage: permchi  y/n(print data?)  n_permutations n_locals n1 n2 ... \n");
	   exit(1) ;
	}

	if( (ni = (int *)malloc( (size_t)(nlocs+1)*sizeof(int) )) == NULL ) {
		perror( "malloc error1\n") ;
		exit(1) ;
	}
	nsam = 0 ;
	for( i = 0; i < nlocs; i++ ) {
	   ni[i] = atoi( argv[4+i] );
	   if( ni[i] < 0 ) {
		fprintf(stderr, "permchi: sample size n%d must not be negative\n", i+1);
		exit(1) ;
	   }
	   nsam += ni[i];
	}
	if( nsam < 1 ) {
	   fprintf(stderr, "permchi: total sample size must be at least 1\n");
	   exit(1) ;
	}

	if( (dij = (double **)malloc( (size_t)nsam*sizeof( double * ) )) == NULL ) {
		perror( "malloc error2\n") ;
		exit(1) ;
	}
	for( i = 0; i < nsam; i++ ) {
		if( (dij[i] = (double *)malloc( (size_t)nsam*sizeof(double) )) == NULL ) {
			perror( "malloc error3\n") ;
			exit(1) ;
		}
	}

	getmatrix(nsam, dij, argv[1][0] );

	fprintf(stdout,"\nSample configuration: ");
	for( loc = 0; loc < nlocs; loc++ ) fprintf(stdout,"%d  ",ni[loc]);

	fprintf(stdout,"\nTest of Roff and Bentzen MBE 6: 539-45\n");
	fprintf(stdout,"Number of permutations: %d \n",nperms);

	pval = permchi(nperms,nsam,nlocs,ni,dij, &numalleles,&ht,&chi);

	fprintf(stdout,"Observed values of statistics:\n");
	fprintf(stdout,"Number of alleles: %d. Ht: %lf  Chi: %lf ( p-value: %lf)\n\n",
	   numalleles,ht,chi,pval);

	if( nlocs > 2 ) {
	   if( (tdij = (double **)malloc( (size_t)nsam*sizeof( double * ) )) == NULL ) {
		perror( "malloc error2\n") ;
		exit(1) ;
	   }
	   for( i = 0; i < nsam; i++ ) {
		if( (tdij[i] = (double *)malloc( (size_t)nsam*sizeof(double) )) == NULL ) {
			perror( "malloc error3\n") ;
			exit(1) ;
		}
	   }
	   /* For each pair (loc1, loc2), copy the two sub-blocks of dij into the
	      top-left corner of tdij: rows/cols 0..ni[loc1]-1 hold loc1's samples,
	      the next ni[loc2] rows/cols hold loc2's. The diagonal is not copied. */
	   for( sti = loc1 = 0; loc1 < nlocs-1; loc1++ ) {
	      endi = sti + ni[loc1] ;
	      off = ni[loc1] ;
	      for( stj = endi, loc2 = loc1+1; loc2 < nlocs; loc2++ ) {
	         endj = stj + ni[loc2] ;
	         /* distances within loc1 */
	         for( i = sti; i < endi-1; i++ )
	            for( j = i+1; j < endi; j++ )
	               tdij[i-sti][j-sti] = tdij[j-sti][i-sti] = dij[i][j] ;
	         /* distances within loc2 */
	         for( i = stj; i < endj-1; i++ )
	            for( j = i+1; j < endj; j++ )
	               tdij[i-stj+off][j-stj+off] = tdij[j-stj+off][i-stj+off] = dij[i][j] ;
	         /* distances between loc1 and loc2 */
	         for( i = sti; i < endi; i++ )
	            for( j = stj; j < endj; j++ )
	               tdij[i-sti][j-stj+off] = tdij[j-stj+off][i-sti] = dij[i][j] ;
	         tni[0] = ni[loc1]; tni[1] = ni[loc2] ;
	         pval = permchi(nperms, ni[loc1]+ni[loc2], 2, tni, tdij, &numalleles, &ht, &chi) ;
	         printf(" %d %d: ",loc1+1,loc2+1);
	         fprintf(stdout,"Number of alleles: %d. Ht: %lf  Chi: %lf ( p-value: %lf)\n",
	            numalleles,ht,chi,pval);
	         stj += ni[loc2] ;
	      }
	      sti += ni[loc1] ;
	   }
	   for( i = 0; i < nsam; i++ ) free( tdij[i] );
	   free( tdij );
	}

	for( i = 0; i < nsam; i++ ) free( dij[i] );
	free( dij );
	free( ni );
	return 0;
}
	
	double
calcexp( nsam,nlocs, ni, na, frequ, expected)
	int nsam, nlocs, *ni, na, *frequ ;
	double **expected;
{
	/*
	 * Fill the expected allele counts under homogeneity:
	 *     expected[loc][i] = ni[loc] * frequ[i] / nsam
	 * for every location loc in 0..nlocs-1 and allele i in 0..na-1, where
	 * ni[loc] is the location's sample size and frequ[i] is the pooled count
	 * of allele i over all nsam samples.
	 *
	 * Returns 0.0. The value carries no information; the function stays
	 * `double` to match the existing `double calcexp()` declaration used by
	 * its caller, and it now actually returns instead of falling off the end.
	 */
	int loc, i;

	for( loc = 0; loc < nlocs; loc++ )
	    for( i = 0; i < na; i++ )
	    	expected[loc][i] = ni[loc]*((double)frequ[i]/nsam) ;

	return( 0.0 );
}
	 


	double
hetero(n,na,freq)
	int n, *freq, na;
{
	/*
	 * Heterozygosity of n samples spread over na alleles with counts freq[]:
	 *     (n / (n-1)) * (1 - sum_i freq[i]^2 / n^2)
	 * This is the small-sample-corrected form, presumably Nei's unbiased gene
	 * diversity (confirm intent).
	 *
	 * Returns 0.0 when n < 2, where the n/(n-1) factor is undefined. The
	 * original divided by zero and produced inf/NaN.
	 */
	int i;
	double het = 0.0, nd;

	if( n < 2 ) return( 0.0 );

	nd = n ;
	/* Square in double: freq[i]*freq[i] in int overflows for counts > 46340. */
	for( i = 0; i < na; i++ ) het += (double)freq[i]*freq[i] ;
	return( (nd/(nd-1.))*(1.0 - het/(nd*nd)) ) ;
}




	int
countalleles(nsam,dij,start, end,frequ,gams)
	int nsam, *frequ, *gams, start, end;
	double **dij;
{
	/*
	 * Group samples start..end-1 into alleles. Samples at distance exactly 0
	 * share an allele.
	 *
	 * Rows are scanned in order. A row opens a new allele unless it is at
	 * distance 0 (looking at dij[row][col], col < row) from some earlier row.
	 * The new allele then claims every later row at distance 0 from it
	 * (dij[row][col], col > row).
	 *
	 * On return, frequ[a] is the number of samples in allele a, and gams[k]
	 * is the allele index assigned to sample k. Assumes dij is symmetric and
	 * zero distance is transitive; otherwise some gams entries may be left
	 * unset. Returns the number of alleles. nsam is not used.
	 */
	int row, col, nal = 0, matched;

	for( row = start; row < end; row++ ) {
		/* Already covered by an earlier allele? */
		matched = 0;
		for( col = start; col < row && !matched; col++ )
			matched = ( dij[row][col] == 0 );
		if( matched )
			continue;

		/* Open allele `nal` with this row and every later identical row. */
		frequ[nal] = 1;
		gams[row] = nal;
		for( col = row + 1; col < end; col++ ) {
			if( dij[row][col] != 0 )
				continue;
			frequ[nal]++;
			gams[col] = nal;
		}
		nal++;
	}
	return( nal );
}



	int
getmatrix(nsam,dij,pflag)
	int nsam;
	double **dij;
	char pflag;
{
	/*
	 * Read a symmetric nsam x nsam distance matrix from stdin into dij.
	 *
	 * Input is consumed in column blocks:
	 *   - Every line up to and including one whose first four characters
	 *     are "OTUs" is skipped (and echoed when pflag == 'y').
	 *   - The word count of that header line fixes the block's last column
	 *     index (jend).
	 *   - Rows 0..jend-1 follow. Each row is a name token, then values for
	 *     columns max(jstart, i+1)..jend.
	 *   - Every value is mirrored into dij[j][i], and dij[i][i] is set to 0.
	 * NOTE(review): the layout is inferred from the parsing code only;
	 * confirm it against the program that writes these files.
	 *
	 * Exits with status 1 on premature end of input, an unreadable token,
	 * or a header that would index past the nsam x nsam matrix.
	 * Returns 0.
	 */
	int i, j, jstart, jend ;
	char gamnam[256], s[1000];
	FILE *pf ;
	int countwords();	/* defined later in this file */

	pf = stdin;

	i = 0 ;
	jend = -1 ;
	while( i < nsam-1 ) {
	  jstart = jend + 1 ;
	  for(;;) {
		/* Unchecked, EOF left s unchanged and this loop spun forever. */
		if( fgets(s, sizeof s, pf) == NULL ) {
			fprintf(stderr, "getmatrix: unexpected end of input (expected an OTUs header)\n");
			exit(1);
		}
		if( pflag == 'y') fputs(s,stdout);
		if( s[0]=='O' && s[1]=='T' && s[2]=='U' && s[3]=='s' ) {
			jend = countwords(s) + jstart - 2 ;
			if( i == 0 ) jstart++ ;
			break;
		}
	  }
	  /* Values are stored at dij[i][j] and dij[j][i] with j <= jend. */
	  if( jend > nsam-1 ) {
		fprintf(stderr, "getmatrix: OTUs header refers to column %d but only %d samples were given\n",
			jend+1, nsam);
		exit(1);
	  }
	  for( i = 0; i < jend; i++ ) {
		/* Bounded read; any excess name characters are discarded so the
		   following numbers stay aligned. */
		if( fscanf(pf, " %255s%*[^ \t\n\v\f\r]", gamnam) != 1 ) {
			fprintf(stderr, "getmatrix: could not read the name of row %d\n", i+1);
			exit(1);
		}
		if( pflag == 'y') fprintf(stdout,"%s",gamnam);
		dij[i][i] = 0.0 ;
		if( jstart < i+1 ) jstart = i+1 ;
		for( j = jstart; j <= jend; j++ ) {
			if( fscanf(pf, " %lf", dij[i]+j) != 1 ) {
				fprintf(stderr, "getmatrix: could not read distance (%d,%d)\n", i+1, j+1);
				exit(1);
			}
			if( pflag == 'y') fprintf(stdout," %g",dij[i][j] );
			dij[j][i] = dij[i][j];
		}
		if( pflag == 'y' ) fprintf(stdout,"\n");
	  }
	}
	return 0;
}

	int
countwords(s)
	char *s;
{
	/*
	 * Count whitespace-separated words in s, stopping at the first newline
	 * or at the string terminator, whichever comes first. Stopping at '\0'
	 * matters: fgets leaves no '\n' when a line is truncated or is the last
	 * line of the file, and the original then read past the buffer.
	 * Characters are passed to isspace as unsigned char, as <ctype.h>
	 * requires.
	 */
	const unsigned char *p;
	int wc = 0, inword = 0;

	for( p = (const unsigned char *)s; *p != '\0' && *p != '\n'; p++ ) {
		if( isspace(*p) )
			inword = 0;
		else if( !inword ) {
			inword = 1;
			wc++;
		}
	}
	return( wc );
}

	int
scramb(ran_vect, n)
	int *ran_vect, n;
{
	/*
	 * Shuffle ran_vect[0..n-1] in place (Fisher-Yates). For i from n-1 down
	 * to 1, element i is swapped with position j = floor(ran1() * (i+1)),
	 * which lies in 0..i while ran1() stays in [0, 1).
	 * Returns 0. The function is declared int, and the original fell off the
	 * end without a value.
	 */
	double ran1();
	int i, j, temp ;

	for( i = n-1 ; i > 0 ; i-- ) {
		j = ran1()*(i + 1) ;
		/* Stay in bounds even if the generator ever returns exactly 1.0. */
		if( j > i ) j = i ;
		temp = ran_vect[j];
		ran_vect[j] = ran_vect[i] ;
		ran_vect[i] = temp ;
	}
	return 0;
}



	double
ran1()
{
	/*
	 * Uniform deviate on [0, 1), taken directly from POSIX drand48().
	 * No seed is set here, so unless something else seeds the generator
	 * (srand48/seed48), every run yields the same sequence.
	 */
	double drand48(void);	/* POSIX, not ISO C: declared locally in case <stdlib.h> hides it */
	double u;

	u = drand48();
	return( u );
}


