00001 #include "mpi.h"
00002
00003 #include "emdata.h"
00004
00005 #include "sirt_Cart.h"
00006 #include "utilcomm_Cart.h"
00007 #include "project3d_Cart.h"
00008 #include "project3d.h"
00009
00010 using namespace EMAN;
00011 #define ROW 0
00012 #define COL 1
00013
00014
00015
00016 int recons3d_sirt_mpi_Cart(MPI_Comm comm_2d , MPI_Comm comm_row,
00017 MPI_Comm comm_col , EMData ** images ,
00018 float * angleshift , EMData *& xvol ,
00019 int nangloc , int radius ,
00020 float lam , int maxit ,
00021 std::string symmetry, float tol)
00022 {
00023 MPI_Status mpistatus;
00024 int ncpus, my2dpid, ierr;
00025 int mycoords[2], dims[2], periods[2];
00026 int srpid, srcoords[2];
00027 double t0;
00028
00029
00030 MPI_Cart_get(comm_2d, 2, dims, periods, mycoords);
00031 MPI_Comm_rank(comm_2d, &my2dpid);
00032 ncpus = dims[ROW]*dims[COL];
00033
00034 int * psize;
00035 int * nbase;
00036
00037 int nangles;
00038 psize = new int[dims[ROW]];
00039 nbase = new int[dims[ROW]];
00040 MPI_Allreduce(&nangloc, &nangles, 1, MPI_INT, MPI_SUM, comm_col);
00041
00042 int nsym = 0;
00043
00044 int nx = images[0]->get_xsize();
00045
00046
00047 if ( radius == -1 ) radius = nx/2 - 1;
00048
00049 Vec3i volsize, origin;
00050 volsize[0] = nx;
00051 volsize[1] = nx;
00052 volsize[2] = nx;
00053 origin[0] = nx/2+1;
00054 origin[1] = nx/2+1;
00055 origin[2] = nx/2+1;
00056
00057
00058
00059
00060
00061
00062 std::vector<float> symangles(3,0.0);
00063
00064
00065 float old_rnorm = 1.00001;
00066
00067
00068
00069
00070
00071
00072
00073 int nrays, nnz;
00074
00075 ierr = getnnz(volsize, radius, origin, &nrays, &nnz);
00076
00077 int nnzloc, nraysloc;
00078 int * ptrs = new int[nrays+1];
00079 int * cord = new int[3*nrays];
00080 int *nnzpart = new int[dims[COL]];
00081 int *nnzbase = new int[dims[COL]+1];
00082 int *ptrstart = new int[dims[COL]+1];
00083
00084 ierr = getcb2sph(volsize, radius, origin, nnz, ptrs, cord);
00085 nnzloc = setpart_gr1(comm_2d, nnz, nnzpart, nnzbase);
00086 nraysloc = sphpart(comm_2d, nrays, ptrs, nnzbase, ptrstart);
00087
00088 int myptrstart = ptrstart[mycoords[COL]];
00089 int nnzall[dims[COL]];
00090 for (int i = 0; i<dims[COL]; i++)
00091 nnzall[i] = ptrs[ptrstart[i+1]] - ptrs[ptrstart[i]];
00092
00093 nnzloc = nnzall[mycoords[COL]];
00094
00095 float *bvol_loc = new float[nnzloc];
00096 float *bvol = new float[nnzloc];
00097 float *xvol_sphloc = new float[nnzloc];
00098 float *pxvol_loc = new float[nnzloc];
00099 float *pxvol = new float[nnzloc];
00100 float * grad_loc = new float[nnzloc];
00101 for (int i=0; i< nnzloc; i++){
00102 xvol_sphloc[i] = 0.0;
00103 bvol[i] = 0.0;
00104 bvol_loc[i] = 0.0;
00105 pxvol_loc[i] = 0.0;
00106 pxvol[i] = 0.0;
00107 grad_loc[i] = 0.0;
00108 }
00109
00110 EMData * current_image;
00111 float phi, theta, psi;
00112 Transform3D RA;
00113 Transform3D Tf;
00114 nsym = Tf.get_nsym(symmetry);
00115 Transform3D::EulerType EULER_SPIDER = Transform3D::SPIDER;
00116 Dict angdict;
00117
00118 int iter = 1;
00119
00120 double rnorm = 0.0, rnorm_loc = 0.0;
00121 double bnorm = 0.0, bnorm_loc = 0.0;
00122
00123 float * image_data;
00124 float * projected_data_loc = new float[nangloc*nx*nx];
00125 float * projected_data = new float[nangloc*nx*nx];
00126
00127 float dm[8];
00128
00129 int restarts = 0;
00130
00131 t0 = MPI_Wtime();
00132
00133 while (iter <= maxit) {
00134 if ( iter == 1 ) {
00135 if ( restarts == 0 ) {
00136
00137
00138 for ( int i = 0 ; i < nangloc ; ++i ) {
00139 current_image = images[i];
00140 image_data = current_image->get_data();
00141
00142
00143 phi = angleshift[5*i + 0];
00144 theta = angleshift[5*i + 1];
00145 psi = angleshift[5*i + 2];
00146
00147
00148
00149
00150
00151 dm[6] = -angleshift[5*i + 3];
00152 dm[7] = -angleshift[5*i + 4];
00153
00154 RA = Transform3D(EULER_SPIDER, phi, theta, psi);
00155 for ( int ns = 1 ; ns < nsym + 1 ; ++ns ) {
00156
00157
00158 Tf = Tf.get_sym(symmetry, ns) * RA;
00159 angdict = Tf.get_rotation(EULER_SPIDER);
00160 phi = (float) angdict["phi"] * PI/180.0;
00161 theta = (float) angdict["theta"] * PI/180.0;
00162 psi = (float) angdict["psi"] * PI/180.0;
00163 make_proj_mat(phi, theta, psi, dm);
00164
00165 ierr = bckpj3_Cart(volsize, nraysloc, nnzloc, dm,
00166 origin, radius, ptrs, cord,
00167 myptrstart, image_data, bvol_loc);
00168 }
00169 }
00170
00171
00172
00173
00174 ierr = MPI_Allreduce (bvol_loc, bvol, nnzloc, MPI_FLOAT,
00175 MPI_SUM, comm_col);
00176
00177 }
00178
00179
00180 bnorm_loc = 0.0;
00181 for ( int j = 0 ; j < nnzloc ; ++j ) {
00182 bnorm_loc += bvol[j] * (double) bvol[j];
00183 grad_loc[j]= bvol[j];
00184 }
00185 ierr = MPI_Allreduce (&bnorm_loc, &bnorm, 1, MPI_DOUBLE, MPI_SUM,
00186 comm_row);
00187
00188 bnorm /= nnz;
00189 bnorm = sqrt(bnorm);
00190
00191 } else {
00192
00193
00194 for ( int ns = 1 ; ns < nsym + 1 ; ++ns ) {
00195
00196 for (int i=0; i<nangloc*nx*nx; i++){
00197 projected_data_loc[i] = 0.0;
00198 projected_data[i] = 0.0;
00199 }
00200
00201
00202 for ( int i = 0 ; i < nangloc ; ++i ) {
00203
00204 RA = Transform3D(EULER_SPIDER, angleshift[5*i + 0],
00205 angleshift[5*i + 1], angleshift[5*i + 2]);
00206
00207
00208
00209
00210
00211 dm[6] = -angleshift[5*i + 3];
00212 dm[7] = -angleshift[5*i + 4];
00213
00214
00215 Tf = Tf.get_sym(symmetry, ns) * RA;
00216 angdict = Tf.get_rotation(EULER_SPIDER);
00217
00218 phi = (float) angdict["phi"] * PI/180.0;
00219 theta = (float) angdict["theta"] * PI/180.0;
00220 psi = (float) angdict["psi"] * PI/180.0;
00221 make_proj_mat(phi, theta, psi, dm);
00222
00223 ierr = fwdpj3_Cart(volsize, nraysloc, nnzloc, dm,
00224 origin, radius, ptrs, cord, myptrstart,
00225 xvol_sphloc, &projected_data_loc[nx*nx*i]);
00226 }
00227
00228
00229
00230 ierr = MPI_Allreduce(projected_data_loc, projected_data,
00231 nangloc*nx*nx, MPI_FLOAT, MPI_SUM, comm_row);
00232
00233
00234 for ( int i = 0 ; i < nangloc ; ++i ) {
00235
00236 RA = Transform3D(EULER_SPIDER, angleshift[5*i + 0],
00237 angleshift[5*i + 1], angleshift[5*i + 2]);
00238
00239
00240
00241
00242
00243 dm[6] = -angleshift[5*i + 3];
00244 dm[7] = -angleshift[5*i + 4];
00245
00246
00247 Tf = Tf.get_sym(symmetry, ns) * RA;
00248 angdict = Tf.get_rotation(EULER_SPIDER);
00249
00250
00251 phi = (float) angdict["phi"] * PI/180.0;
00252 theta = (float) angdict["theta"] * PI/180.0;
00253 psi = (float) angdict["psi"] * PI/180.0;
00254 make_proj_mat(phi, theta, psi, dm);
00255
00256
00257 ierr = bckpj3_Cart(volsize, nraysloc, nnzloc, dm, origin,
00258 radius, ptrs, cord, myptrstart,
00259 &projected_data[nx*nx*i], pxvol_loc);
00260 }
00261 }
00262
00263 ierr = MPI_Allreduce(pxvol_loc, pxvol, nnzloc, MPI_FLOAT, MPI_SUM,
00264 comm_col);
00265
00266 for ( int j = 0 ; j < nnzloc ; ++j ) {
00267 grad_loc[j] = bvol[j];
00268 grad_loc[j] -= pxvol[j];
00269 }
00270 }
00271
00272 rnorm_loc = 0.0;
00273 for ( int j = 0 ; j < nnzloc ; ++j ) {
00274 rnorm_loc += grad_loc[j]* (double) grad_loc[j];
00275 }
00276
00277
00278 ierr = MPI_Allreduce (&rnorm_loc, &rnorm, 1, MPI_DOUBLE, MPI_SUM,
00279 comm_row);
00280 rnorm /= nnz;
00281 rnorm = sqrt(rnorm);
00282 if ( my2dpid == 0 )
00283 printf("iter = %3d, rnorm / bnorm = %11.3e, rnorm = %11.3e\n",
00284 iter, rnorm / bnorm, rnorm);
00285
00286
00287
00288
00289 if ( rnorm / bnorm > old_rnorm ) {
00290
00291 if ( restarts > 20 ) {
00292 if ( my2dpid == 0 )
00293 printf("Failure to converge, even with lam = %f\n", lam);
00294 break;
00295 } else {
00296 ++restarts;
00297 iter = 1;
00298 lam /= 2.0;
00299
00300
00301 old_rnorm = 1.0001;
00302 for ( int j = 0 ; j < nnzloc ; ++j ) {
00303 xvol_sphloc[j] = 0.0;
00304 pxvol_loc[j] = 0.0;
00305 }
00306 if ( my2dpid == 0 )
00307 printf("reducing lam to %11.3e, restarting\n", lam);
00308 continue;
00309 }
00310 }
00311
00312
00313
00314 if ( rnorm / bnorm < tol || rnorm / bnorm > old_rnorm ) {
00315 if ( my2dpid == 0 )
00316 printf("Terminating with rnorm/bnorm = %11.3e, ");
00317 printf("tol = %11.3e, old_rnorm = %11.3e\n",
00318 rnorm/bnorm, tol, old_rnorm);
00319 break;
00320 }
00321
00322 old_rnorm = rnorm / bnorm;
00323
00324
00325 for ( int j = 0 ; j < nnzloc ; ++j ) {
00326 xvol_sphloc[j] += lam * grad_loc[j];
00327
00328 pxvol_loc[j] = 0.0;
00329 }
00330 ++iter;
00331 }
00332
00333 if (my2dpid == 0) printf("Total time in SIRT = %11.3e\n", MPI_Wtime()-t0);
00334
00335
00336
00337 if (mycoords[ROW] == 0 && mycoords[COL] != 0 ){
00338 srcoords[ROW] = 0;
00339 srcoords[COL] = 0;
00340 MPI_Cart_rank(comm_2d, srcoords, &srpid);
00341 MPI_Send(xvol_sphloc, nnzloc, MPI_FLOAT, 0, my2dpid, comm_2d);
00342 }
00343
00344 if (mycoords[ROW] == 0 && mycoords[COL] == 0 ){
00345 xvol->set_size(nx, nx, nx);
00346 xvol->to_zero();
00347 float * voldata = xvol->get_data();
00348 float * xvol_sph = new float[nnz];
00349
00350 for(int i=0; i<nnzloc; i++)
00351 xvol_sph[i] = xvol_sphloc[i];
00352
00353 for(int i=1; i< dims[COL]; i++){
00354 srcoords[ROW] = 0;
00355 srcoords[COL] = i;
00356 MPI_Cart_rank(comm_2d, srcoords, &srpid);
00357 MPI_Recv(&xvol_sph[ptrs[ptrstart[i]]-1], nnzall[i], MPI_FLOAT,
00358 srpid, srpid, comm_2d, &mpistatus);
00359 }
00360
00361
00362 ierr = sph2cb(xvol_sph, volsize, nrays, radius, nnz, ptrs, cord, voldata);
00363 EMDeleteArray(xvol_sph);
00364 }
00365 EMDeleteArray(grad_loc);
00366 EMDeleteArray(pxvol_loc);
00367 EMDeleteArray(bvol_loc);
00368
00369 EMDeleteArray(ptrs);
00370 EMDeleteArray(cord);
00371
00372
00373 EMDeleteArray(psize);
00374 EMDeleteArray(nbase);
00375
00376 delete [] projected_data_loc;
00377
00378 return 0;
00379 }
00380