#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>




/*
 * Forward declarations for functions defined later in this file.  The
 * prototypes use promoted parameter types so they are also compatible with
 * old-style (K&R) definitions of these functions.
 */
int getmatrix(int nsam, double **dij, int pflag);
double permt(int nperms, int wf, int nsam, int nlocs, int *ni, double **dij,
	     double *pksto, double *pks, double *pkt);

/* Print the command-line synopsis to stdout and terminate with status 1. */
static void
usage(void)
{
	printf(
 "usage: perm name-of-statistic weighting-factor y/n(print data?) n_permutations n_locals n1 n2 ... \n");
	exit(1);
}

/*
 * Parse s as a complete base-10 integer in [lo, hi].
 * On success store it in *out and return 0; return -1 for empty input,
 * trailing characters, or a value out of range.
 */
static int
parse_int(const char *s, long lo, long hi, int *out)
{
	char *end;
	long v;

	errno = 0;
	v = strtol(s, &end, 10);
	if (end == s || *end != '\0' || errno == ERANGE || v < lo || v > hi)
		return -1;
	*out = (int)v;
	return 0;
}

/*
 * Allocate an n-by-n matrix of doubles as n separately allocated,
 * zero-filled rows.  Exits the program if any allocation fails.
 */
static double **
alloc_square(int n)
{
	double **m;
	int i;

	if ((m = calloc((size_t)n, sizeof *m)) == NULL) {
		perror("malloc error2\n");
		exit(1);
	}
	for (i = 0; i < n; i++) {
		if ((m[i] = calloc((size_t)n, sizeof *m[i])) == NULL) {
			perror("malloc error3\n");
			exit(1);
		}
	}
	return m;
}

/* Release a matrix obtained from alloc_square(n). */
static void
free_square(double **m, int n)
{
	int i;

	for (i = 0; i < n; i++)
		free(m[i]);
	free(m);
}

/*
 * Copy the distances among two groups of samples into the top-left
 * (n1+n2)-square of tdij: positions 0..n1-1 are samples sti..sti+n1-1 and
 * positions n1..n1+n2-1 are samples stj..stj+n2-1 of dij.  Requires
 * sti + n1 <= stj, so the position->sample mapping is increasing and only
 * the upper triangle of dij (row < column) is ever read.  The diagonal of
 * tdij is not written.
 */
static void
extract_pair(double **dij, double **tdij, int sti, int n1, int stj, int n2)
{
	int n = n1 + n2;
	int a, b, ra, rb;

	for (a = 0; a < n - 1; a++) {
		ra = a < n1 ? sti + a : stj + (a - n1);
		for (b = a + 1; b < n; b++) {
			rb = b < n1 ? sti + b : stj + (b - n1);
			tdij[a][b] = tdij[b][a] = dij[ra][rb];
		}
	}
}

/*
 * perm name-of-statistic weighting-factor y/n n_permutations n_locals n1 n2 ...
 *
 * Reads an nsam x nsam distance matrix (nsam = n1 + n2 + ...) from stdin,
 * with samples ordered group by group.  Reports the observed statistics and
 * the permutation p-value from permt() for all groups together.  When there
 * are more than two groups, it reports them again for every pair of groups.
 * argv[1] is only used as a label prefix in the output.  argv[3][0] == 'y'
 * echoes the input matrix.
 */
int
main(int argc, char *argv[])
{
	int nperms, nlocs, nsam, wf;
	int *ni, tni[2];
	double **dij, **tdij, pval, ksto, kso, kt;
	int i, loc, loc1, loc2, sti, stj;

	/* argv[1..5] are mandatory; argv[6 .. 5+n_locals] hold the group sizes. */
	if (argc < 6)
		usage();
	if (parse_int(argv[4], 1, INT_MAX, &nperms) != 0
	    || parse_int(argv[2], INT_MIN, INT_MAX, &wf) != 0
	    || parse_int(argv[5], 1, INT_MAX, &nlocs) != 0
	    || argc - 6 < nlocs)
		usage();

	if ((ni = malloc((size_t)nlocs * sizeof *ni)) == NULL) {
		perror("malloc error1\n");
		exit(1);
	}
	nsam = 0;
	for (i = 0; i < nlocs; i++) {
		/* Every group needs at least one sample; also keep nsam from overflowing. */
		if (parse_int(argv[6 + i], 1, INT_MAX, &ni[i]) != 0 || ni[i] > INT_MAX - nsam)
			usage();
		nsam += ni[i];
	}

	dij = alloc_square(nsam);
	getmatrix(nsam, dij, argv[3][0]);

	fprintf(stdout,"\nSample configuration: ");
	for (loc = 0; loc < nlocs; loc++)
		fprintf(stdout,"%d  ",ni[loc]);

	fprintf(stdout,"\nNumber of permutations: %d  weighting constant: %d\n",nperms, wf);

	pval = permt(nperms, wf, nsam, nlocs, ni, dij, &ksto, &kso, &kt);

	fprintf(stdout,"Observed values of statistics:\n");
	fprintf(stdout," %sst: %lf , %ss: %lf  %st: %lf ( p-value: %lf)\n\n",
	    argv[1],ksto,argv[1],kso,argv[1],kt,pval);

	/* With more than two groups, repeat the test for every pair of groups. */
	if (nlocs > 2) {
		tdij = alloc_square(nsam);
		for (sti = 0, loc1 = 0; loc1 < nlocs - 1; sti += ni[loc1], loc1++) {
			for (stj = sti + ni[loc1], loc2 = loc1 + 1; loc2 < nlocs;
			    stj += ni[loc2], loc2++) {
				extract_pair(dij, tdij, sti, ni[loc1], stj, ni[loc2]);
				tni[0] = ni[loc1];
				tni[1] = ni[loc2];
				pval = permt(nperms, wf, ni[loc1] + ni[loc2], 2, tni, tdij,
				    &ksto, &kso, &kt);
				printf(" %d %d: ",loc1+1,loc2+1);
				fprintf(stdout," %sst: %lf , %ss: %lf  %st: %lf ( p-value: %lf)\n",
				    argv[1],ksto,argv[1],kso,argv[1],kt,pval);
			}
		}
		free_square(tdij, nsam);
	}

	free_square(dij, nsam);
	free(ni);
	return 0;
}
	
	/*
	 * Weighted mean of the within-group averages kwithin() for nlocs
	 * consecutive groups of ni[0], ni[1], ... positions in ran_ind.
	 * Group loc is weighted by (ni[loc] - wf) / (nsam - nlocs*wf).
	 */
	static double
weighted_within(int wf, int nsam, int nlocs, int *ni, double **dij, int *ran_ind)
{
	double kwithin(int st, int end, double **dij, int *ran_ind);
	double sum = 0.0;
	int loc, start = 0;

	for (loc = 0; loc < nlocs; loc++) {
		sum += kwithin(start, start + ni[loc], dij, ran_ind)
		    * (ni[loc] - wf) / ((double)nsam - nlocs * wf);
		start += ni[loc];
	}
	return sum;
}

/*
 * Permutation test on the nsam x nsam distance matrix dij.  Samples are
 * ordered group by group: nlocs groups with ni[0], ni[1], ... members.
 *
 * For the observed grouping it stores:
 *   *pkt   - kwithin() over all nsam samples
 *   *pks   - weighted_within() over the nlocs groups
 *   *pksto - 1 - (*pks) / (*pkt)
 *
 * Returns the fraction of nperms random reassignments (each a fresh scramb()
 * shuffle of the sample order) whose weighted within-group mean is <= the
 * observed one.  The observed value gets a relative slack of 1e-8 so that
 * ties are not lost to rounding.  Returns NaN when nperms <= 0, since no
 * p-value can be estimated then.  Exits the program if the permutation
 * buffer cannot be allocated.
 */
	double
permt(int nperms, int wf, int nsam, int nlocs, int *ni, double **dij,
      double *pksto, double *pks, double *pkt)
{
	double kwithin(int st, int end, double **dij, int *ran_ind);
	int scramb(int *ran_vect, int n);
	const double eps = 1.0e-8;
	double kso, kt, ks;
	int *ran_ind;
	int i, perm, countks;

	if ((ran_ind = malloc((size_t)nsam * sizeof *ran_ind)) == NULL) {
		perror("malloc error9\n");
		exit(1);
	}
	/* The identity order is the observed assignment of samples to groups. */
	for (i = 0; i < nsam; i++)
		ran_ind[i] = i;

	kt = kwithin(0, nsam, dij, ran_ind);
	kso = weighted_within(wf, nsam, nlocs, ni, dij, ran_ind);
	*pkt = kt;
	*pks = kso;
	*pksto = 1 - kso / kt;

	if (nperms <= 0) {
		free(ran_ind);
		return NAN;
	}

	kso += eps * kso;
	countks = 0;
	for (perm = 0; perm < nperms; perm++) {
		scramb(ran_ind, nsam);
		ks = weighted_within(wf, nsam, nlocs, ni, dij, ran_ind);
		if (ks <= kso)
			countks++;
	}

	free(ran_ind);
	return (double)countks / (double)nperms;
}

	/*
	 * Mean of dij over all unordered pairs of the samples at positions
	 * st .. end-1 of ran_ind, i.e. the average pairwise distance within
	 * that stretch of the (possibly permuted) sample order.
	 *
	 * Only dij[ran_ind[i]][ran_ind[j]] with st <= i < j < end is read.
	 * With fewer than two samples there are no pairs.  0.0 is returned
	 * rather than 0/0 = NaN, which would otherwise turn the weighted sums in
	 * permt() into NaN and make every permutation comparison false.
	 */
	double
kwithin(int st, int end, double **dij, int ran_ind[])
{
	double sum = 0.0;
	int i, j, n = end - st;

	if (n < 2)
		return 0.0;

	for (i = st; i < end - 1; i++)
		for (j = i + 1; j < end; j++)
			sum += dij[ran_ind[i]][ran_ind[j]];

	/* n*(n-1)/2 pairs were summed. */
	return sum / ((double)n * (n - 1.0) / 2.0);
}




	/* Report a malformed or truncated distance matrix on stderr and terminate. */
	static void
matrix_input_error(const char *what)
{
	fprintf(stderr, "getmatrix: %s\n", what);
	exit(1);
}

/*
 * Read the upper triangle of an nsam x nsam distance matrix from stdin into
 * dij, mirroring every value read into the lower triangle.
 *
 * Input arrives in column blocks.  Lines are skipped (and echoed when
 * pflag == 'y') until one starting with "OTUs".  That header's word count
 * sets the last column index jend of the block.  Rows 0 .. jend-1 follow,
 * each a sample name then the values for columns max(jstart, row+1) .. jend.
 * Blocks repeat until the last column reaches nsam-1.
 * NOTE(review): layout inferred from the parsing below; presumably a
 * blocked upper-triangular matrix printout - confirm against a sample file.
 *
 * nsam  - number of samples (rows/columns of dij)
 * dij   - nsam row pointers, each to nsam doubles
 * pflag - 'y' echoes everything read to stdout (int so callers may pass a
 *         promoted char)
 *
 * Exits the program on premature end of input, an unreadable value, or a
 * header describing more columns than nsam.  Returns 0.
 */
	int
getmatrix(int nsam, double **dij, int pflag)
{
	int countwords(char *s);
	int i, j, jstart, jend;
	char gamnam[35], s[1000];
	FILE *pf = stdin;

	i = 0;
	jend = -1;
	while (i < nsam - 1) {
		jstart = jend + 1;
		/* Skip ahead to the next "OTUs" column-header line. */
		for (;;) {
			if (fgets(s, sizeof s, pf) == NULL)
				matrix_input_error("unexpected end of input before an OTUs header line");
			if (pflag == 'y')
				fputs(s, stdout);
			if (s[0] == 'O' && s[1] == 'T' && s[2] == 'U' && s[3] == 's') {
				jend = countwords(s) + jstart - 2;
				if (i == 0)
					jstart++;
				break;
			}
		}
		/* Every column index written below must lie inside dij. */
		if (jend > nsam - 1)
			matrix_input_error("OTUs header lists more columns than there are samples");

		for (i = 0; i < jend; i++) {
			/* %34s bounds the copy into gamnam; %*[...] discards the rest of an over-long name. */
			if (fscanf(pf, " %34s%*[^ \t\n\r\f\v]", gamnam) < 1)
				matrix_input_error("missing sample name");
			if (pflag == 'y')
				fprintf(stdout, "%s", gamnam);
			dij[i][i] = 0.0;
			if (jstart < i + 1)
				jstart = i + 1;
			for (j = jstart; j <= jend; j++) {
				if (fscanf(pf, " %lf", dij[i] + j) != 1)
					matrix_input_error("missing or malformed distance value");
				if (pflag == 'y')
					fprintf(stdout, " %g", dij[i][j]);
				dij[j][i] = dij[i][j];
			}
			if (pflag == 'y')
				fprintf(stdout, "\n");
		}
	}
	return 0;
}

	/*
	 * Count the whitespace-separated words in s that precede its first
	 * newline.  Counting also stops at the terminating NUL.  A line that
	 * fgets() truncated, or a final line with no '\n', is therefore handled
	 * without reading past the end of the string.
	 */
	int
countwords(char *s)
{
	const unsigned char *p;
	int wc = 0, in_word = 0;

	/* isspace() needs an unsigned char value (or EOF); plain char may be negative. */
	for (p = (const unsigned char *)s; *p != '\0' && *p != '\n'; p++) {
		if (isspace(*p))
			in_word = 0;
		else if (!in_word) {
			in_word = 1;
			wc++;
		}
	}
	return wc;
}

	/*
	 * Shuffle ran_vect[0 .. n-1] in place with a Fisher-Yates pass.  Walking
	 * down from the last slot, each slot is swapped with one picked from
	 * itself and the slots below it, using ran1().  Nothing happens when n < 2.
	 */
	int
scramb(int *ran_vect, int n)
{
	double ran1(void);
	int top = n;

	while (--top > 0) {
		int pick = (int)(ran1() * (top + 1));
		int held = ran_vect[top];

		ran_vect[top] = ran_vect[pick];
		ran_vect[pick] = held;
	}
	return 0;
}



	/* Pseudo-random deviate on [0, 1), taken from the POSIX drand48() generator. */
	double
ran1(void)
{
	/* Declared locally: <stdlib.h> may not declare drand48() in strict ISO C modes. */
	double drand48(void);
	double u;

	u = drand48();
	return u;
}


