/*  Updated 9/18/2000.  Earlier version had a serious bug. */
/*  Modified 9/19/2000 to allow more than two samples, and carry out 
     global test and test each pair of samples.              */
/*    Modified 3/9/2001 added (1-eps) factor to make sure round-off error
	doesn't cause problems.  */

/*
This program carries out the "nearest neighbor test" described in
Hudson, R.R. 2000  "A new statistic for detecting genetic differentiation"
Genetics 155: 2011-2014.
The data consist of a matrix of pairwise differences between sampled
sequences, in the format described in the readme.  An example data file is
sample.dat.  The program can usually be compiled with  gcc -o snn snn.c  and
then run as

snn y 1000 20 27 <sample.dat >sample.out   .

On the command line,  20 is the sample size from
locality one, and 27 is the sample size from locality two.  1000 is
the number of permutations to carry out.  y indicates that the input
data matrix should be copied to the output.
  If the first sample, in this example, were actually from two localities, 
say the first 8 from one locality and the next 12 from another locality,
then one can run the program as follows:  
snn n 1000 8 12 27 <sample.dat >sample2.out  
The program will carry out a global test for structure, and will test each
pair of samples for significant structure.
*/


#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>




/* Parse a strictly positive int argument; exit with a message otherwise. */
static int
parse_positive(const char *arg, const char *what)
{
	char *end;
	long v;

	errno = 0;
	v = strtol(arg, &end, 10);
	if (end == arg || *end != '\0' || errno == ERANGE || v <= 0 || v > INT_MAX) {
		fprintf(stderr, "snn: invalid %s \"%s\" (expected a positive integer)\n",
			what, arg);
		exit(1);
	}
	return (int)v;
}

/* Allocate an n x n matrix as n separately allocated rows; exit on failure. */
static double **
alloc_matrix(int n)
{
	double **m;
	int i;

	if ((m = malloc((size_t)n * sizeof *m)) == NULL) {
		perror("malloc error2");
		exit(1);
	}
	for (i = 0; i < n; i++) {
		if ((m[i] = malloc((size_t)n * sizeof *m[i])) == NULL) {
			perror("malloc error3");
			exit(1);
		}
	}
	return m;
}

/* Release a matrix obtained from alloc_matrix(); a NULL matrix is ignored. */
static void
free_matrix(double **m, int n)
{
	int i;

	if (m == NULL)
		return;
	for (i = 0; i < n; i++)
		free(m[i]);
	free(m);
}

/*
 * snn entry point.
 *
 * Usage: snn y|n n_permutations n1 n2 [n3 ...]
 * Reads the distance matrix for n1+n2+... sequences from stdin (echoing it
 * when the first argument starts with 'y').  It then prints Snn and its
 * permutation p-value for all samples together.  When there are more than
 * two samples, it also prints them for every pair of samples.
 *
 * NOTE(review): nothing in this file seeds drand48(), so repeated runs use
 * the same permutations and give identical p-values.
 */
int
main(int argc, char *argv[])
{
	/* Defined later in this file; declared here so no call is implicit. */
	int getmatrix(int nsam, double **dij, int pflag);
	int loaddij(int iloc, int jloc, int *ni, double **dij, double **dij2);
	double permt(int nperms, int nsam, int nlocs, int *ni, double **dij,
		     double *psnno);

	int nperms, nlocs, nsam, i, iloc, jloc, ni2[2];
	int *ni;
	double **dij, **dij2 = NULL, pval, snno;

	/* Need the print flag, the permutation count and at least two samples. */
	if (argc < 5) {
		fprintf(stderr,
 "usage: snn  y/n(print data?) n_permutations n1 n2  ... \n");
		exit(1);
	}
	nperms = parse_positive(argv[2], "number of permutations");
	nlocs = argc - 3;
	if ((ni = malloc((size_t)nlocs * sizeof *ni)) == NULL) {
		perror("malloc error1");
		exit(1);
	}
	nsam = 0;
	for (i = 0; i < nlocs; i++) {
		ni[i] = parse_positive(argv[3 + i], "sample size");
		if (ni[i] > INT_MAX - nsam) {
			fprintf(stderr, "snn: total sample size is too large\n");
			exit(1);
		}
		nsam += ni[i];
	}

	dij = alloc_matrix(nsam);
	getmatrix(nsam, dij, argv[1][0]);

	fprintf(stdout, "\nSample configuration: ");
	for (i = 0; i < nlocs; i++)
		fprintf(stdout, "%d  ", ni[i]);
	fprintf(stdout, "\nNumber of permutations: %d  \n", nperms);

	pval = permt(nperms, nsam, nlocs, ni, dij, &snno);

	if (nlocs > 2) fprintf(stdout, " Global test:\n     ");
	fprintf(stdout, " Snn: ");
	fprintf(stdout, "  %lf ( p-value: %lf)\n", snno, pval);

	if (nlocs > 2) {
		/* Scratch space for each two-sample sub-matrix; nsam rows always suffice. */
		dij2 = alloc_matrix(nsam);
		fprintf(stdout, " Pairwise tests of samples:\n");
		for (iloc = 0; iloc < nlocs - 1; iloc++)
			for (jloc = iloc + 1; jloc < nlocs; jloc++) {
				loaddij(iloc, jloc, ni, dij, dij2);
				ni2[0] = ni[iloc];
				ni2[1] = ni[jloc];
				pval = permt(nperms, ni2[0] + ni2[1], 2, ni2, dij2, &snno);
				fprintf(stdout, "     %d %d:   Snn: ", iloc + 1, jloc + 1);
				fprintf(stdout, "  %lf ( p-value: %lf)\n", snno, pval);
			}
	}

	free_matrix(dij2, nsam);
	free_matrix(dij, nsam);
	free(ni);
	return 0;
}

	/*
	 * Build in dij2 the distance sub-matrix for the two samples iloc and jloc.
	 *
	 * Samples occupy consecutive rows/columns of dij with sizes ni[0], ni[1], ...
	 * In dij2 the ni[iloc] members of sample iloc come first, followed by the
	 * ni[jloc] members of sample jloc.  Let g(p) be the row of dij that
	 * dij2 index p comes from.  For every p < q, dij[g(p)][g(q)] is copied to
	 * both dij2[p][q] and dij2[q][p], and the diagonal of dij2 is set to 0.
	 *
	 * dij2 must have at least ni[iloc] + ni[jloc] rows and columns.
	 * Returns 0.
	 */
	int
loaddij(int iloc, int jloc, int *ni, double **dij, double **dij2)
{
	int start1 = 0, start2 = 0;	/* first row of each sample in dij */
	int n1, n, i, p, q, gp, gq;

	for (i = 0; i < iloc; i++) start1 += ni[i];
	for (i = 0; i < jloc; i++) start2 += ni[i];
	n1 = ni[iloc];
	n = n1 + ni[jloc];

	for (p = 0; p < n; p++) {
		gp = (p < n1) ? start1 + p : start2 + (p - n1);
		dij2[p][p] = 0.0;
		for (q = p + 1; q < n; q++) {
			gq = (q < n1) ? start1 + q : start2 + (q - n1);
			dij2[p][q] = dij[gp][gq];
			dij2[q][p] = dij2[p][q];
		}
	}
	return 0;
}
	
	/*
	 * Permutation test for Snn.
	 *
	 * Computes the observed statistic with the identity labelling and stores
	 * it in *psnno.  It then recomputes the statistic nperms times, each time
	 * after shuffling the labelling with scramb().  Returns the fraction of
	 * permuted values that are >= the observed one, within a relative
	 * tolerance eps.
	 *
	 * nperms should be positive; nperms == 0 yields 0/0 (NaN).
	 * Exits if the index buffer cannot be allocated.
	 */
	double
permt(int nperms, int nsam, int nlocs, int *ni, double **dij, double *psnno)
{
	/* Defined elsewhere in this file; declared so the calls are not implicit. */
	double snnf(int, int, int *, double **, int *);
	int scramb(int *, int);

	/* Relative tolerance: a permuted Snn equal to the observed one except for
	   floating-point round-off must still count as ">=". */
	const double eps = 1.0e-8;
	double snno, snn;
	int i, perm, count;
	int *ran_ind;

	if ((ran_ind = malloc((size_t)nsam * sizeof *ran_ind)) == NULL) {
		perror("malloc error9");
		exit(1);
	}

	/* Identity labelling gives the observed statistic. */
	for (i = 0; i < nsam; i++)
		ran_ind[i] = i;
	snno = *psnno = snnf(nsam, nlocs, ni, dij, ran_ind);

	/* Counting up to nperms avoids the signed overflow the old
	   decrementing loop hit for negative nperms. */
	count = 0;
	for (perm = 0; perm < nperms; perm++) {
		scramb(ran_ind, nsam);
		snn = snnf(nsam, nlocs, ni, dij, ran_ind);
		if (snn >= snno * (1.0 - eps))
			count++;
	}
	free(ran_ind);
	return (double)count / (double)nperms;
}


	/*
	 * Read the nsam x nsam matrix of pairwise differences from stdin into dij.
	 *
	 * The matrix is read as one or more column blocks.  Each block starts at
	 * the next line beginning with "OTUs".  That header is taken to hold
	 * "OTUs" plus one label per column, so a header of w words covers columns
	 * jstart .. jstart+w-2.  While no rows have been read yet (i == 0), the
	 * first of those columns is skipped, because column 0 has no entries
	 * above the diagonal.  After the header come rows 0 .. jend-1.  Each row
	 * is a name token followed by that row's above-diagonal entries within
	 * the block's columns.  Entries are mirrored into the lower triangle and
	 * the diagonal is set to 0.  Reading stops once the columns reach nsam-1.
	 *
	 * When pflag == 'y' the input is echoed to stdout: skipped lines
	 * verbatim, data rows re-printed from the parsed values.  pflag is an
	 * int so this definition matches callers that pass a (promoted) char.
	 *
	 * Exits with a message on stderr in three cases: the input ends early,
	 * a name or number cannot be read, or a header names columns beyond
	 * nsam-1.  Returns 0.
	 */
	int
getmatrix(int nsam, double **dij, int pflag)
{
	/* Defined later in this file; declared so the call is not implicit. */
	int countwords(char *s);

	int i, j, jstart, jend;
	char gamnam[35], s[1000];
	FILE *pf = stdin;

	i = 0;
	jend = -1;
	while (i < nsam - 1) {
		jstart = jend + 1;

		/* Skip (and optionally echo) lines up to the next block header.
		   Stopping at EOF prevents the old endless loop on short input. */
		for (;;) {
			if (fgets(s, sizeof s, pf) == NULL) {
				fprintf(stderr,
				    "snn: input ended before the whole matrix was read\n");
				exit(1);
			}
			if (pflag == 'y') fputs(s, stdout);
			if (s[0] == 'O' && s[1] == 'T' && s[2] == 'U' && s[3] == 's') {
				jend = countwords(s) + jstart - 2;
				if (i == 0) jstart++;
				break;
			}
		}

		/* Every index written below must stay inside the nsam x nsam matrix. */
		if (jend > nsam - 1) {
			fprintf(stderr,
			    "snn: matrix has more columns than the %d samples given on the command line\n",
			    nsam);
			exit(1);
		}

		for (i = 0; i < jend; i++) {
			/* Width limit keeps the name inside gamnam. */
			if (fscanf(pf, " %34s", gamnam) != 1) {
				fprintf(stderr, "snn: missing name for row %d\n", i + 1);
				exit(1);
			}
			/* Discard the rest of an over-long name so it is not parsed
			   as a distance; matches nothing for names that fit. */
			(void)fscanf(pf, "%*[^ \t\n\v\f\r]");
			if (pflag == 'y') fprintf(stdout, "%s", gamnam);
			dij[i][i] = 0.0;
			/* Only entries above the diagonal are stored in the file. */
			if (jstart < i + 1) jstart = i + 1;
			for (j = jstart; j <= jend; j++) {
				if (fscanf(pf, " %lf", &dij[i][j]) != 1) {
					fprintf(stderr,
					    "snn: bad or missing distance at row %d, column %d\n",
					    i + 1, j + 1);
					exit(1);
				}
				if (pflag == 'y') fprintf(stdout, " %g", dij[i][j]);
				dij[j][i] = dij[i][j];
			}
			if (pflag == 'y') fprintf(stdout, "\n");
		}
	}
	/* The last row has no above-diagonal entries, so the loop never reaches
	   its diagonal element. */
	if (nsam > 0) dij[nsam - 1][nsam - 1] = 0.0;
	return 0;
}

        /*
         * Count the whitespace-separated words in s, stopping at the first
         * newline or at the terminating NUL, whichever comes first.
         */
        int
countwords(char *s)
{
	int wc = 0;
	int inword = 0;
	unsigned char c;

	/* Testing for '\0' as well as '\n' matters in two cases: fgets() leaves
	   no newline on a truncated line, or on a final line that lacks one.
	   Scanning only for '\n' would then run off the end of the string. */
	for (; (c = (unsigned char)*s) != '\0' && c != '\n'; s++) {
		/* isspace() needs an unsigned char value; a negative plain char
		   would be undefined behaviour. */
		if (isspace(c))
			inword = 0;
		else if (!inword) {
			inword = 1;
			wc++;
		}
	}
	return wc;
}

	/*
	 * Shuffle ran_vect[0 .. n-1] in place (Fisher-Yates, from the top down).
	 * Each element i is swapped with the element at a position j in [0, i]
	 * chosen by ran1().  This assumes ran1() returns values in [0, 1), so
	 * that j never exceeds i.  Returns 0.
	 */
	int
scramb(int *ran_vect, int n)
{
	double ran1(void);
	int i, j, temp;

	for (i = n - 1; i > 0; i--) {
		j = (int)(ran1() * (i + 1));
		temp = ran_vect[j];
		ran_vect[j] = ran_vect[i];
		ran_vect[i] = temp;
	}
	return 0;
}



	/*
	 * Uniform pseudo-random deviate taken from the POSIX drand48() generator,
	 * i.e. a double in [0, 1).  The generator is not seeded here.
	 */
	double
ran1()
{
	/* Declared locally: <stdlib.h> only provides it with POSIX/XSI
	   extensions enabled. */
	extern double drand48(void);
	double u;

	u = drand48();
	return u;
}


/*
 * Compute the nearest-neighbour statistic Snn for one labelling of the data.
 *
 * Positions 0 .. nsam-1 are split into npop consecutive groups of sizes
 * config[0], config[1], ...  The sample at position ind is ran_ind[ind], so
 * an identity ran_ind gives the observed grouping and a shuffled one gives a
 * permuted grouping.
 *
 * For each position, minofdij() finds the sample's nearest-neighbour distance
 * over all other samples and how many samples tie at it.  howman() counts how
 * many of those neighbours sit in the same group.  The fraction within the
 * group is accumulated, and the sum divided by nsam is returned.
 *
 * Assumes nsam >= 2 so that every sample has at least one neighbour.
 */
double snnf(int nsam, int npop, int *config, double **dij, int *ran_ind)
{
	/* Defined later in this file; declared so the calls are not implicit. */
	double minofdij(int, double *, int, int *);
	int howman(int, double *, int *, int, double, int, int);

	int pop, i, ind, indstart, nmall, nmwithin;
	double mall, p = 0.0;

	for (ind = pop = 0; pop < npop; pop++) {
		indstart = ind;		/* first position belonging to this group */
		for (i = 0; i < config[pop]; i++, ind++) {
			double *row = dij[ran_ind[ind]];

			mall = minofdij(nsam, row, ran_ind[ind], &nmall);
			nmwithin = howman(nsam, row, ran_ind, ind, mall, indstart,
			    config[pop]);
			/* Ties are shared: only the within-group share counts. */
			p += (double)nmwithin / (double)nmall;
		}
	}
	return p / nsam;
}

/*
 * Count the nearest neighbours that fall inside the focal sample's group.
 *
 * The group is the run of positions indstart .. indstart+nsubpop-1 in the
 * labelling ran_ind; position ind (the focal sample itself) is excluded.
 * vec is the focal sample's row of distances and mall its nearest-neighbour
 * distance.  A group member counts when vec[ran_ind[pos]] == mall exactly.
 * nsam is not used.
 */
int
howman(int nsam, double *vec, int *ran_ind, int ind, double mall, int indstart, int nsubpop)
{
	int hits = 0;
	int pos = indstart;
	int stop = indstart + nsubpop;

	(void)nsam;
	while (pos < stop) {
		if (pos != ind && vec[ran_ind[pos]] == mall)
			++hits;
		++pos;
	}
	return hits;
}

/*
 * Nearest-neighbour distance of one sample.
 *
 * vec holds the distances from sample ind to samples 0 .. nsam-1.  Returns
 * the smallest vec[k] over k != ind.  It stores in *pnmall how many k != ind
 * attain exactly that value; ties are compared with ==.  The running
 * minimum is seeded from vec[0], or from vec[1] when ind == 0, so nsam must
 * be at least 2.
 */
double
minofdij(int nsam, double *vec, int ind, int *pnmall)
{
	double best = (ind == 0) ? vec[1] : vec[0];
	int ties = 0;
	int k;

	for (k = 0; k < nsam; k++) {
		double d;

		if (k == ind)
			continue;
		d = vec[k];
		if (d < best) {
			best = d;
			ties = 1;
		} else if (d == best) {
			ties++;
		}
	}
	*pnmall = ties;
	return best;
}




