/*
A simple 1D FEM mesh preparation and testing tool
*/
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include "charm++.h"
#include "fem.h"

// Build-time switch: when nonzero (e.g. compile with -DENABLE_MIG=1),
// driver() calls FEM_Migrate() after every mesh update. Off by default.
#ifndef ENABLE_MIG
# define ENABLE_MIG 0
#endif

// Mesh-update callback; declared here because init() and driver() refer to
// it before its definition later in this file.
extern "C" void mesh_updated(int param);

const int np=2; /*Nodes per element: each 1D element joins two nodes*/
const int nSparseT=10; //Number of sparse regions; init() gives each an equal, contiguous run of nodes
/**
 * In-memory copy of one 1D FEM mesh: nelems elements of np nodes each,
 * one double coordinate per node, and nSparseT sparse regions whose
 * records each name one node.
 *
 * The object owns every array it points to: they are allocated with new[]
 * in allocate() and released with delete[] by the destructor. Copying is
 * therefore disabled (see the private declarations below), because a
 * memberwise copy would make two objects delete[] the same arrays.
 */
class myMesh {
public:
	int nelems, nnodes;
	int *conn; //Connectivity: element e uses nodes conn[e*np+0 .. e*np+np-1]
	double *nodes; //Node data: just one coordinate per node
	
	int nSparse[nSparseT]; //Number of records in each sparse region
	int *sparse[nSparseT]; //sparse[s][i]: node index named by record i of region s
	
private:
	// Not copyable: declared but intentionally never defined (pre-C++11
	// idiom), so an accidental copy fails to compile/link instead of
	// causing a double delete[] in ~myMesh().
	myMesh(const myMesh &);
	myMesh &operator=(const myMesh &);
	
	/// Exchange all arrays (node coordinates, element connectivity, and every
	/// sparse region) with FEM mesh handle `mesh`. The same call fills this
	/// object from a read mesh (in the constructor) and copies it into a write
	/// mesh (in write()). NOTE(review): the transfer direction appears to be
	/// chosen by the FEM layer from the kind of handle passed — confirm.
	void readWriteData(int mesh) {
		FEM_Mesh_data(mesh,FEM_NODE,FEM_DATA+0,
			nodes, 0,nnodes, FEM_DOUBLE,1);
		FEM_Mesh_data(mesh,FEM_ELEM+0,FEM_CONN,
			conn, 0,nelems, FEM_INDEX_0,np);
		for (int s=0;s<nSparseT;s++) {
			FEM_Mesh_data(mesh,FEM_SPARSE+s,FEM_CONN,
				sparse[s], 0,nSparse[s], FEM_INDEX_0,1);
		}
	}
	
	/// Record the sizes and allocate (uninitialized) storage for ne elements,
	/// nn nodes, and nSparse_[s] records in each sparse region s.
	/// nSparse_ may point at this->nSparse itself (the mesh-reading
	/// constructor does this); the element-wise copy is then a no-op.
	void allocate(int ne,int nn,const int *nSparse_) {
		nelems=ne; nnodes=nn;
		conn=new int[np*nelems];
		nodes=new double[nnodes];
		for (int s=0;s<nSparseT;s++) {
			nSparse[s]=nSparse_[s];
			sparse[s]=new int[nSparse[s]];
		}
	}
	/// Release every array created by allocate().
	void deallocate(void) {
		delete[] conn;
		delete[] nodes;
		for (int s=0;s<nSparseT;s++) 
			delete[] sparse[s];
	}
	
	/// Store a global-number array of n entries, all -1, for `entity`.
	void writeNoGlobals(int mesh,int entity,int n) {
		int *gno=new int[n];
		for (int i=0;i<n;i++) gno[i]=-1;
		FEM_Mesh_data(mesh,entity,FEM_GLOBALNO,
			gno, 0,n, FEM_INDEX_0,1);
		delete[] gno;
	}
	
public: 
	/// Build this mesh with ne elements and nn nodes (contents uninitialized);
	/// nSparse_ gives the record count of each of the nSparseT sparse regions.
	myMesh(int ne,int nn,const int *nSparse_) {
		allocate(ne,nn,nSparse_);
	}
	/// Build this mesh from this FEM mesh structure: query every entity's
	/// length, allocate to match, then copy the data in.
	myMesh(int mesh) {
		for (int s=0;s<nSparseT;s++) 
			nSparse[s]=FEM_Mesh_get_length(mesh,FEM_SPARSE+s);
		allocate(FEM_Mesh_get_length(mesh,FEM_ELEM+0),
		         FEM_Mesh_get_length(mesh,FEM_NODE),
			 nSparse);
		readWriteData(mesh);
	}
	
	~myMesh() {deallocate();}
	
	/// Write this mesh's data to this FEM mesh data structure:
	void write(int mesh) {
		readWriteData(mesh);
	}
	
	/// Write that we have no global numbers: every node, element, and
	/// sparse record gets global number -1.
	void writeNoGlobals(int mesh) {
		writeNoGlobals(mesh,FEM_NODE,nnodes);
		writeNoGlobals(mesh,FEM_ELEM+0,nelems);
		for (int s=0;s<nSparseT;s++)
			writeNoGlobals(mesh,FEM_SPARSE+s,nSparse[s]);
	}
	
	/// Abort unless the sparse records cover every node exactly once.
	void check(void);
};

// Requested number of elements in the test mesh (there are dim+1 nodes).
// init() rounds this down so that dim+1 is a multiple of nSparseT, and
// mesh_updated() derives the expected sparse-region size from the rounded value.
int dim=15000;

/// Ask the FEM framework to build one layer of node-adjacent ghosts
/// around element type 0 (the 1D edge elements).
void pushGhost(void) {
  // The edge's local nodes 0 and 1, each used as its own adjacency tuple.
  static const int edgeTuples[]={0,1};

  // NOTE(review): arguments presumably mean 1 node per tuple and "also add
  // ghost nodes" — confirm against the FEM ghost-layer API.
  FEM_Add_ghost_layer(1,1);

  // Element type 0 contributes 2 tuples, described by edgeTuples.
  FEM_Add_ghost_elem(0,2,edgeTuples);
}

/// Discard the current default write mesh and install a freshly allocated,
/// empty one in its place, so later writes start from a clean mesh.
void resetWriteMesh(void) {
  const int stale=FEM_Mesh_default_write();
  FEM_Mesh_deallocate(stale);

  const int fresh=FEM_Mesh_allocate();
  FEM_Mesh_set_default_write(fresh);
}

extern "C" void
init(void)
{
  FEM_Print("--------- init called -----------");
  
  int nChunks=FEM_Num_partitions();
  
  //Prepare a new mesh:
  int s;
  int nSparse[nSparseT];
  int nPerSparse=(dim+1)/nSparseT;
  dim=nPerSparse*nSparseT-1; // makes (dim+1) a multiple of nSparseT
  for (s=0;s<nSparseT;s++) nSparse[s]=nPerSparse;
  myMesh m(/*nelems=*/ dim,  /*nnodes=*/ dim+1, nSparse);
  for (s=0;s<nSparseT;s++) nSparse[s]=0;
  for (int n=0;n<m.nnodes;n++) {
    m.nodes[n]=n/(float)m.nnodes;
    s=n/nPerSparse; // Divide domain into contiguous sparse regions
    m.sparse[s][nSparse[s]++]=n;
  }
  for (s=0;s<nSparseT;s++) 
  	if (nSparse[s]!=m.nSparse[s])
		CkAbort("Logic error: sparse counts don't match!");
  for (int e=0;e<m.nelems;e++) {
    m.conn[e*np+0]=e;
    m.conn[e*np+1]=e+1;
  }
  
  //Push the new mesh into the framework:
  m.write(FEM_Mesh_default_write());
  pushGhost();
  
  int c;
#if 1
  //Test out FEM_Serial_split (immediately writes out output files)
  FEM_Print("Calling serial split");
  FEM_Serial_split(nChunks);
  for (c=0;c<nChunks;c++) {
    FEM_Serial_begin(c);
    myMesh m(FEM_Mesh_default_read());
    CkPrintf("  serial split chunk %d> %d nodes, %d elements\n",c,m.nnodes,m.nelems);
  }
  
  resetWriteMesh();
 
  //Test out FEM_Serial_assemble (reads files written by FEM_Serial_split)
  FEM_Print("Calling serial join");
  for (c=0;c<nChunks;c++) {
    FEM_Serial_read(c,nChunks);
    myMesh m(FEM_Mesh_default_read());
    CkPrintf("  serial join chunk %d> %d nodes, %d elements\n",c,m.nnodes,m.nelems);
    m.write(FEM_Mesh_default_write());
  }
  FEM_Serial_assemble();
  mesh_updated(123);
#endif
  FEM_Print("---------- end of init -------------");
}

/// Abort the run unless `is` lies within 1e-6 of `shouldBe`.
/// `what` names the quantity being checked in the failure message.
/// A NaN on either side counts as a failure.
void testEqual(double is,double shouldBe,const char *what) {
	const double tolerance=0.000001;
	if (fabs(is-shouldBe)<tolerance)
		return; // test passed
	
	CkPrintf("%s test FAILED-- expected %f, got %f (pe %d)\n",
	         what,shouldBe,is,CkMyPe());
	CkAbort("FEM Test failed\n");
}


/// Sanity-check the sparse data. Every node index stored in a sparse record
/// must name an existing node, and every node must be named by exactly one
/// sparse record across all regions. Aborts via CkAbort on any violation.
void myMesh::check(void) {
	int *marker=new int[nnodes]; // marker[n]: number of sparse records naming node n
	for (int i=0;i<nnodes;i++) marker[i]=0;
	for (int s=0;s<nSparseT;s++) {
		for (int i=0;i<nSparse[s];i++) {
			const int node=sparse[s][i];
			// Validate before indexing: a corrupt index would otherwise
			// write outside marker[] instead of producing a diagnostic.
			if (node<0 || node>=nnodes)
				CkAbort("Sparse data refers to a nonexistent node!");
			marker[node]++;
		}
	}
	for (int i=0;i<nnodes;i++) {
		if (marker[i]<1) CkAbort("Node is missing sparse data!");
		if (marker[i]>1) CkAbort("Node is covered by multiple sparse data!");
	}
	delete[] marker;
}

extern "C" void
driver(void)
{
  int myID=FEM_My_partition();
  if (myID==0) FEM_Print("----------- begin driver ------------");
  for (int loop=0;loop<10;loop++) {
    {
      //Read this mesh out of the framework:
      myMesh m(FEM_Mesh_default_read());

      m.check();
      CkPrintf("    loop %d: chunk %d> %d nodes, %d elements\n",
    	loop,FEM_My_partition(),m.nnodes,m.nelems);
    
      //Prepare mesh to be updated:
      m.write(FEM_Mesh_default_write());
      // m.writeNoGlobals(FEM_Mesh_default_write());
    }
  
    FEM_Update_mesh(mesh_updated,123,FEM_MESH_UPDATE);
    
    if (ENABLE_MIG) { // loop%3==0) { 
      if (myID==0) FEM_Print("----- migrating -----");
      FEM_Migrate();
    }
  }
  if (myID==0) FEM_Print("----------- end driver ------------");
}

extern "C" void
mesh_updated(int param)
{
  CkPrintf("mesh_updated(%d) called.\n",param);
  testEqual(param,123,"mesh_updated param");
  myMesh m(FEM_Mesh_default_read());
  
  CkPrintf("mesh_updated> %d nodes, %d elements\n",m.nnodes,m.nelems);

  for (int n=0;n<m.nnodes;n++) {
    testEqual(m.nodes[n],n/(float)m.nnodes,"node data");
  }
  for (int e=0;e<m.nelems;e++) {
    testEqual(m.conn[e*np+0],e,"element connectivity (col 0)");
    testEqual( m.conn[e*np+1],e+1,"element connectivity (col 1)");
  }
  int s;
  int nPerSparse=(dim+1)/nSparseT;
  for (s=0;s<nSparseT;s++) {
  	testEqual(m.nSparse[s],nPerSparse,"sparse count");
	int first=s*nPerSparse;
	for (int i=0;i<m.nSparse[s];i++) {
		testEqual(m.sparse[s][i],first+i,"sparse connectivity");
	}
  }  
  
  resetWriteMesh();
  m.write(FEM_Mesh_default_write());
  // pushGhost(); //< FIXME: causes duplicate layers!
}
