00001
00002
00003
00004
00005
00006
00007
00008
00009 #ifndef AMROC_AMRINTERPOLATION_H
00010 #define AMROC_AMRINTERPOLATION_H
00011
00018 #include "Geom.h"
00019 #include "Interpolation.h"
00020
00027 template <class VectorType, int dim>
00028 class AMRInterpolation : public Geom<typename VectorType::InternalDataType,dim>,
00029 public AMRBase<VectorType,dim> {
00030 typedef typename VectorType::InternalDataType DataType;
00031 typedef AMRBase<VectorType,dim> base;
00032 typedef Geom<DataType,dim> geom_base;
00033
00034 public:
00035 typedef typename base::vec_grid_fct_type vec_grid_fct_type;
00036 typedef typename base::vec_grid_data_type vec_grid_data_type;
00037 typedef GridFunction<DataType,dim> grid_fct_type;
00038 typedef GridData<DataType,dim> grid_data_type;
00039 typedef Vector<DataType,dim> point_type;
00040 typedef Interpolation<VectorType,dim> interpolation_type;
00041
00042 AMRInterpolation() : base(), _ErrorValue(-1.e37) {
00043 _interpolation = new interpolation_type();
00044 }
00045
00046 virtual ~AMRInterpolation() {
00047 if (_interpolation) delete _interpolation;
00048 }
00049
00050 void PointsValues(vec_grid_fct_type& u, const double& t, const int& Npoints,
00051 const point_type* xc, VectorType* uv, const point_type* xct=0) {
00052
00053 if (Npoints<=0 || !uv) return;
00054 for (register int n=0; n<Npoints; n++)
00055 uv[n] = _ErrorValue;
00056
00057 for (register int l=0; l<=FineLevel(base::GH()); l++) {
00058 int Time = CurrentTime(base::GH(),l);
00059 PointsValues(u,Time,l,t,Npoints,xc,uv,xct);
00060 }
00061 }
00062
00063 void PointsValues(vec_grid_fct_type& u, const int& Time, const int& Level, const double& t,
00064 const int& Npoints, const point_type* xc, VectorType* uv, const point_type* xct=0) {
00065
00066 if (Npoints<=0 || !uv) return;
00067 bool* lused = new bool[Npoints];
00068
00069 forall (u,Time,Level,c)
00070 BBox bb = u.interiorbbox(Time,Level,c);
00071 int ncl;
00072 if (xct)
00073 ncl = geom_base::LocalList(base::GH(),bb,Npoints,xct,lused);
00074 else
00075 ncl = geom_base::LocalList(base::GH(),bb,Npoints,xc,lused);
00076 if (ncl>0) {
00077 register int n, nl;
00078 point_type* xcl = new point_type[ncl];
00079 VectorType* uvl = new VectorType[ncl];
00080 for (n=0, nl=0; n<Npoints; n++)
00081 if (lused[n]) xcl[nl++] = xc[n];
00082
00083 GetGrid(u(Time,Level,c),Time,Level,c,t,ncl,xcl,uvl);
00084
00085 for (n=0, nl=0; n<Npoints; n++)
00086 if (lused[n]) uv[n] = uvl[nl++];
00087
00088 delete [] xcl;
00089 delete [] uvl;
00090 }
00091 end_forall
00092
00093 delete [] lused;
00094 }
00095
00096 int LocalPoints(vec_grid_fct_type& u, const int& Time, const int& Level,
00097 const int& Npoints, const point_type* xc, bool* lused) {
00098 int ncl = 0;
00099 if (Npoints<=0 || !xc) return ncl;
00100
00101 for (register int n=0; n<Npoints; n++)
00102 lused[n] = false;
00103
00104 bool* lused_help = new bool[Npoints];
00105 forall (u,Time,Level,c)
00106 BBox bb = u.interiorbbox(Time,Level,c);
00107 int ncl_help = geom_base::LocalList(base::GH(),bb,Npoints,xc,lused_help);
00108 if (ncl_help>0)
00109 for (register int n=0; n<Npoints; n++)
00110 if (lused_help[n]) lused[n] = true;
00111 end_forall
00112 delete [] lused_help;
00113
00114 for (register int n=0; n<Npoints; n++)
00115 if (lused[n]) ncl++;
00116 return ncl;
00117 }
00118
00119 virtual void GetGrid(vec_grid_data_type& gdu, const int& Time, const int& Level,
00120 const int& c, const double& t, const int& Npoints,
00121 const point_type* xcl, VectorType* uvl) {
00122 if (_interpolation)
00123 Interpolation_().Interpolate(base::GH(),gdu,Npoints,xcl,uvl,_ErrorValue);
00124 }
00125
00126 void PointsValuesPar(vec_grid_fct_type& u, const double& t, const int& Npoints,
00127 const point_type* xc, VectorType* uv) {
00128 if (!uv) return;
00129 for (register int n=0; n<Npoints; n++)
00130 uv[n] = _ErrorValue;
00131
00132 int Np,Npr;
00133 int num = comm_service::proc_num();
00134 point_type *xcp = 0;
00135 PointsAllGather(Npoints,xc,Np,xcp);
00136 VectorType* uvp = new VectorType[Np*num];
00137
00138 PointsValues(u,t,Np*num,xcp,uvp);
00139 delete [] xcp;
00140
00141 VectorType* uvpr = 0;
00142 DataAllScatter(Np,uvp,Npr,uvpr);
00143 for (int p=0; p<num; p++)
00144 for (register int n=0; n<Npoints; n++)
00145 if (uvpr[p*Npr+n] != _ErrorValue)
00146 uv[n] = uvpr[p*Npr+n];
00147
00148 delete [] uvpr;
00149 delete [] uvp;
00150 }
00151
00152 void PointsValues(grid_fct_type& u, const double& t, const int& Npoints,
00153 const point_type* xc, DataType* uv, const point_type* xct=0) {
00154
00155 if (Npoints<=0 || !uv) return;
00156 for (register int n=0; n<Npoints; n++)
00157 uv[n] = _ErrorValue;
00158
00159 for (register int l=0; l<=FineLevel(base::GH()); l++) {
00160 int Time = CurrentTime(base::GH(),l);
00161 PointsValues(u,Time,l,t,Npoints,xc,uv,xct);
00162 }
00163 }
00164
00165 void PointsValues(grid_fct_type& u, const int& Time, const int& Level, const double& t,
00166 const int& Npoints, const point_type* xc, DataType* uv, const point_type* xct=0) {
00167
00168 if (Npoints<=0 || !uv) return;
00169 bool* lused = new bool[Npoints];
00170
00171 forall (u,Time,Level,c)
00172 BBox bb = u.interiorbbox(Time,Level,c);
00173 int ncl;
00174 if (xct)
00175 ncl = geom_base::LocalList(base::GH(),bb,Npoints,xct,lused);
00176 else
00177 ncl = geom_base::LocalList(base::GH(),bb,Npoints,xc,lused);
00178 if (ncl>0) {
00179 register int n, nl;
00180 point_type* xcl = new point_type[ncl];
00181 DataType* uvl = new DataType[ncl];
00182 for (n=0, nl=0; n<Npoints; n++)
00183 if (lused[n]) xcl[nl++] = xc[n];
00184
00185 GetGrid(u(Time,Level,c),Time,Level,c,t,ncl,xcl,uvl);
00186
00187 for (n=0, nl=0; n<Npoints; n++)
00188 if (lused[n]) uv[n] = uvl[nl++];
00189
00190 delete [] xcl;
00191 delete [] uvl;
00192 }
00193 end_forall
00194
00195 delete [] lused;
00196 }
00197
00198 int LocalPoints(grid_fct_type& u, const int& Time, const int& Level,
00199 const int& Npoints, const point_type* xc, bool* lused) {
00200 int ncl = 0;
00201 if (Npoints<=0 || !xc) return ncl;
00202
00203 for (register int n=0; n<Npoints; n++)
00204 lused[n] = false;
00205
00206 bool* lused_help = new bool[Npoints];
00207 forall (u,Time,Level,c)
00208 BBox bb = u.interiorbbox(Time,Level,c);
00209 int ncl_help = geom_base::LocalList(base::GH(),bb,Npoints,xc,lused_help);
00210 if (ncl_help>0)
00211 for (register int n=0; n<Npoints; n++)
00212 if (lused_help[n]) lused[n] = true;
00213 end_forall
00214 delete [] lused_help;
00215
00216 for (register int n=0; n<Npoints; n++)
00217 if (lused[n]) ncl++;
00218 return ncl;
00219 }
00220
00221 virtual void GetGrid(grid_data_type& gdu, const int& Time, const int& Level,
00222 const int& c, const double& t, const int& Npoints,
00223 const point_type* xcl, DataType* uvl) {
00224 if (_interpolation)
00225 Interpolation_().Interpolate(base::GH(),gdu,Npoints,xcl,uvl,_ErrorValue);
00226 }
00227
00228 void PointsValuesPar(grid_fct_type& u, const double& t, const int& Npoints,
00229 const point_type* xc, DataType* uv) {
00230 if (!uv) return;
00231 for (register int n=0; n<Npoints; n++)
00232 uv[n] = _ErrorValue;
00233
00234 int Np,Npr;
00235 int num = comm_service::proc_num();
00236 point_type* xcp = 0;
00237 PointsAllGather(Npoints,xc,Np,xcp);
00238 DataType* uvp = new DataType[Np*num];
00239
00240 PointsValues(u,t,Np*num,xcp,uvp);
00241 delete [] xcp;
00242
00243 DataType* uvpr = 0;
00244 DataAllScatter(Np,uvp,Npr,uvpr);
00245 for (int p=0; p<num; p++)
00246 for (register int n=0; n<Npoints; n++)
00247 if (uvpr[p*Npr+n] != _ErrorValue)
00248 uv[n] = uvpr[p*Npr+n];
00249
00250 delete [] uvpr;
00251 delete [] uvp;
00252 }
00253
00254 void PointsAllGather(int Nsnd, const point_type* snd, int& Nrcv, point_type*& rcv) {
00255 if (rcv) delete [] rcv;
00256
00257 if (!comm_service::dce() || comm_service::proc_world() == 1) {
00258 Nrcv = Nsnd;
00259 rcv = new point_type[Nrcv];
00260 for (register int i=0; i<Nsnd; i++) rcv[i] = snd[i];
00261 return;
00262 }
00263
00264 #ifdef DAGH_NO_MPI
00265 #else
00266 int num = comm_service::proc_num();
00267 int R = MPI_Allreduce(&Nsnd, &Nrcv, 1, MPI_INT, MPI_MAX, comm_service::comm());
00268 if ( MPI_SUCCESS != R )
00269 comm_service::error_die("AMRInterpolation::ArrayAllGather","MPI_Allreduce",R);
00270 rcv = new point_type[Nrcv*num];
00271 point_type* tmpsnd = new point_type[Nrcv*num];
00272 for (register int i=0; i<Nsnd; i++) tmpsnd[i] = snd[i];
00273 for (register int i=Nsnd; i<Nrcv; i++) tmpsnd[i] = _ErrorValue;
00274 int sndsize = Nrcv*sizeof(point_type);
00275 int rcvsize = Nrcv*sizeof(point_type);
00276 R = MPI_Allgather((void *)tmpsnd, sndsize, MPI_BYTE, (void *)rcv, rcvsize,
00277 MPI_BYTE, comm_service::comm());
00278 if ( MPI_SUCCESS != R )
00279 comm_service::error_die( "AMRInterpolation::PointsAllGather", "MPI_Allgather", R );
00280 delete [] tmpsnd;
00281 #endif
00282 }
00283
00284 void DataAllScatter(int Nsnd, const VectorType* snd, int& Nrcv, VectorType*& rcv) {
00285 Nrcv = Nsnd;
00286 if (rcv) delete [] rcv;
00287
00288 if (!comm_service::dce() || comm_service::proc_world() == 1) {
00289 rcv = new VectorType[Nrcv];
00290 for (register int i=0; i<Nsnd; i++) rcv[i] = snd[i];
00291 return;
00292 }
00293
00294 #ifdef DAGH_NO_MPI
00295 #else
00296 int num = comm_service::proc_num();
00297 rcv = new VectorType[Nrcv*num];
00298 int sndsize = Nsnd*sizeof(VectorType);
00299 int rcvsize = Nrcv*sizeof(VectorType);
00300 int R = MPI_Alltoall((void*) snd, sndsize, MPI_BYTE, (void *)rcv, rcvsize,
00301 MPI_BYTE, comm_service::comm());
00302 if ( MPI_SUCCESS != R )
00303 comm_service::error_die("AMRInterpolation::DataAllScatter","MPI_Alltoall",R);
00304 #endif
00305 }
00306
00307 void DataAllScatter(int Nsnd, const DataType* snd, int& Nrcv, DataType*& rcv) {
00308 Nrcv = Nsnd;
00309 if (rcv) delete [] rcv;
00310
00311 if (!comm_service::dce() || comm_service::proc_world() == 1) {
00312 rcv = new DataType[Nrcv];
00313 for (register int i=0; i<Nsnd; i++) rcv[i] = snd[i];
00314 return;
00315 }
00316
00317 #ifdef DAGH_NO_MPI
00318 #else
00319 int num = comm_service::proc_num();
00320 rcv = new DataType[Nrcv*num];
00321 int sndsize = Nsnd*sizeof(DataType);
00322 int rcvsize = Nrcv*sizeof(DataType);
00323 int R = MPI_Alltoall((void*) snd, sndsize, MPI_BYTE, (void *)rcv, rcvsize,
00324 MPI_BYTE, comm_service::comm());
00325 if ( MPI_SUCCESS != R )
00326 comm_service::error_die("AMRInterpolation::DataAllScatter","MPI_Alltoall",R);
00327 #endif
00328 }
00329
00330 void ArrayCombine(const int& TargetNode, const int& Nvalues, DataType* dat) {
00331 #ifdef DAGH_NO_MPI
00332 #else
00333 if (comm_service::dce() && comm_service::proc_num() > 1) {
00334 int num = comm_service::proc_num();
00335 MPI_Status status;
00336 int R;
00337 int size = sizeof(DataType)*Nvalues;
00338 int tag = 10000;
00339 if (MY_PROC!=TargetNode) {
00340 R = MPI_Send((void *)dat, size, MPI_BYTE, TargetNode, tag, comm_service::comm());
00341 if ( MPI_SUCCESS != R )
00342 comm_service::error_die( "AMRInterpolation::ArrayCombine", "MPI_Send", R );
00343 }
00344 else
00345 for (int proc=0; proc<num; proc++) {
00346 if (proc==TargetNode) continue;
00347 DataType* datp = new DataType[Nvalues];
00348 R = MPI_Recv((void *)datp, size, MPI_BYTE, proc, tag, comm_service::comm(), &status);
00349 if ( MPI_SUCCESS != R )
00350 comm_service::error_die( "AMRInterpolation::ArrayCombine", "MPI_Recv", R );
00351 for (register int n=0; n<Nvalues; n++)
00352 if (datp[n] != _ErrorValue)
00353 dat[n] = datp[n];
00354 delete [] datp;
00355 }
00356 }
00357 #endif
00358 }
00359
00360 inline void SetInterpolation(interpolation_type* inter) {
00361 if (_interpolation) delete _interpolation;
00362 _interpolation = inter;
00363 }
00364 inline interpolation_type& Interpolation_() { return *_interpolation; }
00365 inline const interpolation_type& Interpolation_() const { return *_interpolation; }
00366
00367 inline void SetErrorValue(const DataType val) { _ErrorValue = val; }
00368 inline DataType ErrorValue() const { return _ErrorValue; }
00369
00370 protected:
00371 DataType _ErrorValue;
00372 interpolation_type* _interpolation;
00373 };
00374
00375 #endif