00001 #include "mpi.h"
00002
00003 #include "emdata.h"
00004
00005 #include "project3d.h"
00006 #include "sirt.h"
00007 #include "utilcomm.h"
00008
00009 using namespace EMAN;
00010
00011 int recons3d_sirt_mpi(MPI_Comm comm , EMData ** images, float * angleshift ,
00012 EMData *& xvol, int nangloc , int radius ,
00013 float lam , int maxit , std::string symmetry,
00014 float tol)
00015 {
00016 int ncpus, mypid, ierr;
00017 double t0;
00018
00019 MPI_Status mpistatus;
00020 MPI_Comm_size(comm,&ncpus);
00021 MPI_Comm_rank(comm,&mypid);
00022
00023 int * psize;
00024 int * nbase;
00025
00026 int nangles;
00027 psize = new int[ncpus];
00028 nbase = new int[ncpus];
00029 MPI_Allreduce(&nangloc, &nangles, 1, MPI_INT, MPI_SUM, comm);
00030
00031
00032 int nsym = 0;
00033
00034 int nx = images[0]->get_xsize();
00035
00036
00037 if ( radius == -1 ) radius = nx/2 - 1;
00038
00039 Vec3i volsize, origin;
00040 volsize[0] = nx;
00041 volsize[1] = nx;
00042 volsize[2] = nx;
00043 origin[0] = nx/2+1;
00044 origin[1] = nx/2+1;
00045 origin[2] = nx/2+1;
00046
00047
00048
00049
00050
00051 xvol->set_size(nx, nx, nx);
00052 xvol->to_zero();
00053 float * voldata = xvol->get_data();
00054
00055
00056 std::vector<float> symangles(3,0.0);
00057
00058
00059 float old_rnorm = 1.00001;
00060
00061 int nrays, nnz;
00062
00063 ierr = getnnz(volsize, radius, origin, &nrays, &nnz);
00064
00065 int * ptrs = new int[nrays+1];
00066 int * cord = new int[3*nrays];
00067
00068 float * xvol_sph = new float[nnz];
00069
00070
00071 ierr = cb2sph(voldata, volsize, radius, origin, nnz, ptrs, cord, xvol_sph);
00072
00073
00074
00075 float * bvol_loc = new float[nnz];
00076
00077 float * bvol = new float[nnz];
00078
00079 float * pxvol = new float[nnz];
00080
00081 float * pxvol_loc = new float[nnz];
00082
00083 for ( int i = 0 ; i < nnz ; ++i ) {
00084 xvol_sph[i] = 0.0;
00085 bvol_loc[i] = 0.0;
00086 bvol[i] = 0.0;
00087 pxvol[i] = 0.0;
00088 pxvol_loc[i] = 0.0;
00089 }
00090
00091 EMData * current_image;
00092 float phi, theta, psi;
00093 Transform3D RA;
00094 Transform3D Tf;
00095 nsym = Tf.get_nsym(symmetry);
00096 Transform3D::EulerType EULER_SPIDER = Transform3D::SPIDER;
00097 Dict angdict;
00098
00099 int iter = 1;
00100
00101 double rnorm = 0.0;
00102 double bnorm = 0.0;
00103 float * grad = new float[nnz];
00104 float * image_data;
00105 float * projected_data = new float[nx*nx];
00106 float dm[8];
00107
00108 int restarts = 0;
00109
00110 t0 = MPI_Wtime();
00111 while (iter <= maxit) {
00112 if ( iter == 1 ) {
00113 if ( restarts == 0 ) {
00114
00115 for ( int i = 0 ; i < nangloc ; ++i ) {
00116 current_image = images[i];
00117 image_data = current_image->get_data();
00118
00119
00120 phi = angleshift[5*i + 0];
00121 theta = angleshift[5*i + 1];
00122 psi = angleshift[5*i + 2];
00123
00124
00125
00126
00127
00128 dm[6] = -angleshift[5*i + 3];
00129 dm[7] = -angleshift[5*i + 4];
00130
00131
00132 RA = Transform3D(EULER_SPIDER, phi, theta, psi);
00133 for ( int ns = 1 ; ns < nsym + 1 ; ++ns ) {
00134
00135
00136 Tf = Tf.get_sym(symmetry, ns) * RA;
00137 angdict = Tf.get_rotation(EULER_SPIDER);
00138 phi = (float) angdict["phi"] * PI/180.0;
00139 theta = (float) angdict["theta"] * PI/180.0;
00140 psi = (float) angdict["psi"] * PI/180.0;
00141 make_proj_mat(phi, theta, psi, dm);
00142
00143 ierr = bckpj3(volsize, nrays, nnz, dm, origin, radius, ptrs, cord,
00144 image_data, bvol_loc);
00145 }
00146 }
00147
00148 ierr = MPI_Allreduce(bvol_loc, bvol, nnz, MPI_FLOAT, MPI_SUM, comm);
00149
00150 }
00151
00152
00153 for ( int j = 0 ; j < nnz ; ++j ) {
00154 bnorm += bvol[j] * (double) bvol[j];
00155 grad[j] = bvol[j];
00156 }
00157 bnorm /= nnz;
00158 bnorm = sqrt(bnorm);
00159
00160 } else {
00161 for ( int i = 0 ; i < nangloc ; ++i ) {
00162
00163 RA = Transform3D(EULER_SPIDER, angleshift[5*i + 0],
00164 angleshift[5*i + 1], angleshift[5*i + 2]);
00165
00166
00167
00168
00169
00170 dm[6] = -angleshift[5*i + 3];
00171 dm[7] = -angleshift[5*i + 4];
00172 for ( int ns = 1 ; ns < nsym + 1 ; ++ns ) {
00173
00174 Tf = Tf.get_sym(symmetry, ns) * RA;
00175 angdict = Tf.get_rotation(EULER_SPIDER);
00176
00177 for ( int j = 0 ; j < nx*nx ; ++j ) {
00178 projected_data[j] = 0.0;
00179 }
00180 phi = (float) angdict["phi"] * PI/180.0;
00181 theta = (float) angdict["theta"] * PI/180.0;
00182 psi = (float) angdict["psi"] * PI/180.0;
00183 make_proj_mat(phi, theta, psi, dm);
00184
00185 ierr = fwdpj3(volsize, nrays, nnz, dm, origin, radius, ptrs,
00186 cord, xvol_sph, projected_data);
00187 ierr = bckpj3(volsize, nrays, nnz, dm, origin, radius, ptrs,
00188 cord, projected_data, pxvol_loc);
00189 }
00190 }
00191
00192 ierr = MPI_Allreduce(pxvol_loc, pxvol, nnz, MPI_FLOAT, MPI_SUM, comm);
00193
00194 for ( int j = 0 ; j < nnz ; ++j ) {
00195 grad[j] = bvol[j];
00196 grad[j] -= pxvol[j];
00197 }
00198 }
00199 rnorm = 0.0;
00200 for ( int j = 0 ; j < nnz ; ++j ) {
00201 rnorm += grad[j]* (double) grad[j];
00202 }
00203 rnorm /= nnz;
00204 rnorm = sqrt(rnorm);
00205 if ( mypid == 0 ) printf("iter = %3d, rnorm / bnorm = %11.3e, rnorm = %11.3e\n",
00206 iter, rnorm / bnorm, rnorm);
00207
00208
00209
00210 if ( rnorm / bnorm > old_rnorm ) {
00211
00212 if ( restarts > 20 ) {
00213 if ( mypid == 0 )
00214 printf("Failure to converge, even with lam = %11.3e\n", lam);
00215 break;
00216 } else {
00217 ++restarts;
00218 iter = 1;
00219 lam /= 2.0;
00220
00221
00222 old_rnorm = 1.0001;
00223 for ( int j = 0 ; j < nnz ; ++j ) {
00224 xvol_sph[j] = 0.0;
00225 pxvol_loc[j] = 0.0;
00226 }
00227 if ( mypid == 0 ) printf("reducing lam to %11.3e, restarting\n", lam);
00228 continue;
00229 }
00230 }
00231
00232 if ( rnorm / bnorm < tol || rnorm / bnorm > old_rnorm ) {
00233 if ( mypid == 0 )
00234 printf("Terminating with rnorm/bnorm = %11.3e, tol = %11.3e, ");
00235 printf("old_rnorm = %11.3e\n", rnorm/bnorm, tol, old_rnorm);
00236 break;
00237 }
00238
00239 old_rnorm = rnorm / bnorm;
00240
00241 for ( int j = 0 ; j < nnz ; ++j ) {
00242 xvol_sph[j] += lam * grad[j];
00243
00244 pxvol_loc[j] = 0.0;
00245 }
00246
00247 ++iter;
00248 }
00249 if (mypid == 0) printf("Total time in SIRT = %11.3e\n", MPI_Wtime()-t0);
00250
00251
00252 ierr = sph2cb(xvol_sph, volsize, nrays, radius, nnz, ptrs, cord, voldata);
00253
00254 EMDeleteArray(grad);
00255 EMDeleteArray(pxvol);
00256 EMDeleteArray(bvol);
00257 EMDeleteArray(pxvol_loc);
00258 EMDeleteArray(bvol_loc);
00259
00260 EMDeleteArray(ptrs);
00261 EMDeleteArray(cord);
00262 EMDeleteArray(xvol_sph);
00263 EMDeleteArray(projected_data);
00264
00265 EMDeleteArray(psize);
00266 EMDeleteArray(nbase);
00267
00268 return 0;
00269 }
00270