#include "mpi.h"
Include dependency graph for sirt.h: (graph image not included in this text export)
Graph of the files that directly or indirectly include this file: (graph image not included in this text export)
Go to the source code of this file (sirt.h).
Defines
    #define PI 3.141592653589793
Functions
    int recons3d_sirt_mpi (MPI_Comm comm, EMData **images, float *angleshift, EMData *&xvol, int nangloc, int radius=-1, float lam=1.0e-4, int maxit=100, std::string symmetry="c1", float tol=1.0e-3)
Definition at line 11 of file sirt.cpp. References bckpj3(), cb2sph(), cord, dm, EMDeleteArray(), fwdpj3(), EMAN::EMData::get_data(), EMAN::EMData::get_xsize(), getnnz(), grad, ierr, images, make_proj_mat(), nnz, nrays, nx, phi, PI, ptrs, EMAN::EMData::set_size(), sph2cb(), sqrt(), theta, EMAN::EMData::to_zero(), and EMAN::Vec3i. Referenced by ali3d_d(), and main(). 00015 { 00016 int ncpus, mypid, ierr; 00017 double t0; 00018 00019 MPI_Status mpistatus; 00020 MPI_Comm_size(comm,&ncpus); 00021 MPI_Comm_rank(comm,&mypid); 00022 00023 int * psize; 00024 int * nbase; 00025 00026 int nangles; 00027 psize = new int[ncpus]; 00028 nbase = new int[ncpus]; 00029 MPI_Allreduce(&nangloc, &nangles, 1, MPI_INT, MPI_SUM, comm); 00030 00031 00032 int nsym = 0; 00033 // get image size from first image 00034 int nx = images[0]->get_xsize(); 00035 00036 // make radius as large as possible if the user didn't provide one 00037 if ( radius == -1 ) radius = nx/2 - 1; 00038 00039 Vec3i volsize, origin; 00040 volsize[0] = nx; 00041 volsize[1] = nx; 00042 volsize[2] = nx; 00043 origin[0] = nx/2+1; 00044 origin[1] = nx/2+1; 00045 origin[2] = nx/2+1; 00046 00047 // this is not currently needed, because the stack that gets passed to sirt 00048 // will have its background subtracted already 00049 // ierr = CleanStack(comm, images, nangloc, radius, volsize, origin); 00050 00051 xvol->set_size(nx, nx, nx); 00052 xvol->to_zero(); 00053 float * voldata = xvol->get_data(); 00054 00055 // vector of symmetrized angles 00056 std::vector<float> symangles(3,0.0); 00057 00058 // kluge, make sure if its 1.0 + epsilon it still works; 00059 float old_rnorm = 1.00001; 00060 00061 int nrays, nnz; 00062 00063 ierr = getnnz(volsize, radius, origin, &nrays, &nnz); 00064 00065 int * ptrs = new int[nrays+1]; 00066 int * cord = new int[3*nrays]; 00067 00068 float * xvol_sph = new float[nnz]; 00069 00070 // this is just to set ptrs and cord, voldata is all 0.0 at this point 00071 ierr = cb2sph(voldata, volsize, 
radius, origin, nnz, ptrs, cord, xvol_sph); 00072 00073 // arrays to hold volume data: 00074 // initial backprojected volume, local 00075 float * bvol_loc = new float[nnz]; 00076 // initial backprojected volume, global 00077 float * bvol = new float[nnz]; 00078 // P^T * P * xvol, local 00079 float * pxvol = new float[nnz]; 00080 // P^T * P * xvol, global 00081 float * pxvol_loc = new float[nnz]; 00082 00083 for ( int i = 0 ; i < nnz ; ++i ) { 00084 xvol_sph[i] = 0.0; 00085 bvol_loc[i] = 0.0; 00086 bvol[i] = 0.0; 00087 pxvol[i] = 0.0; 00088 pxvol_loc[i] = 0.0; 00089 } 00090 00091 EMData * current_image; 00092 float phi, theta, psi; 00093 Transform3D RA; 00094 Transform3D Tf; 00095 nsym = Tf.get_nsym(symmetry); 00096 Transform3D::EulerType EULER_SPIDER = Transform3D::SPIDER; 00097 Dict angdict; 00098 00099 int iter = 1; 00100 00101 double rnorm = 0.0; 00102 double bnorm = 0.0; 00103 float * grad = new float[nnz]; 00104 float * image_data; 00105 float * projected_data = new float[nx*nx]; 00106 float dm[8]; 00107 00108 int restarts = 0; 00109 00110 t0 = MPI_Wtime(); 00111 while (iter <= maxit) { 00112 if ( iter == 1 ) { 00113 if ( restarts == 0 ) { 00114 // only do this if we aren't restarting due to lam being too large 00115 for ( int i = 0 ; i < nangloc ; ++i ) { 00116 current_image = images[i]; 00117 image_data = current_image->get_data(); 00118 // retrieve the angles and shifts associated with each image 00119 // from the array angleshift. 00120 phi = angleshift[5*i + 0]; 00121 theta = angleshift[5*i + 1]; 00122 psi = angleshift[5*i + 2]; 00123 00124 // need to change signs here because the input shifts 00125 // are shifts associated with 2-D images. 
Because 00126 // the projection operator actually shifts the volume 00127 // the signs should be negated here 00128 dm[6] = -angleshift[5*i + 3]; 00129 dm[7] = -angleshift[5*i + 4]; 00130 00131 // make an instance of Transform3D with the angles 00132 RA = Transform3D(EULER_SPIDER, phi, theta, psi); 00133 for ( int ns = 1 ; ns < nsym + 1 ; ++ns ) { 00134 // compose it with each symmetry rotation in turn 00135 // shifts (stored in dm[6] and dm[7] remain fixed 00136 Tf = Tf.get_sym(symmetry, ns) * RA; 00137 angdict = Tf.get_rotation(EULER_SPIDER); 00138 phi = (float) angdict["phi"] * PI/180.0; 00139 theta = (float) angdict["theta"] * PI/180.0; 00140 psi = (float) angdict["psi"] * PI/180.0; 00141 make_proj_mat(phi, theta, psi, dm); 00142 // accumulate the back-projected images in bvol_loc 00143 ierr = bckpj3(volsize, nrays, nnz, dm, origin, radius, ptrs, cord, 00144 image_data, bvol_loc); 00145 } 00146 } 00147 // reduce bvol_loc so each processor has the full volume 00148 ierr = MPI_Allreduce(bvol_loc, bvol, nnz, MPI_FLOAT, MPI_SUM, comm); 00149 00150 } 00151 00152 // calculate the norm of the backprojected volume 00153 for ( int j = 0 ; j < nnz ; ++j ) { 00154 bnorm += bvol[j] * (double) bvol[j]; 00155 grad[j] = bvol[j]; 00156 } 00157 bnorm /= nnz; 00158 bnorm = sqrt(bnorm); 00159 00160 } else { 00161 for ( int i = 0 ; i < nangloc ; ++i ) { 00162 // retrieve the angles and shifts from angleshift 00163 RA = Transform3D(EULER_SPIDER, angleshift[5*i + 0], 00164 angleshift[5*i + 1], angleshift[5*i + 2]); 00165 00166 // need to change signs here because the input shifts 00167 // are shifts associated with 2-D images. 
Because 00168 // the projection operator actually shifts the volume 00169 // the signs should be negated here 00170 dm[6] = -angleshift[5*i + 3]; 00171 dm[7] = -angleshift[5*i + 4]; 00172 for ( int ns = 1 ; ns < nsym + 1 ; ++ns ) { 00173 // iterate over symmetries 00174 Tf = Tf.get_sym(symmetry, ns) * RA; 00175 angdict = Tf.get_rotation(EULER_SPIDER); 00176 // reset the array in which projected data are stored 00177 for ( int j = 0 ; j < nx*nx ; ++j ) { 00178 projected_data[j] = 0.0; 00179 } 00180 phi = (float) angdict["phi"] * PI/180.0; 00181 theta = (float) angdict["theta"] * PI/180.0; 00182 psi = (float) angdict["psi"] * PI/180.0; 00183 make_proj_mat(phi, theta, psi, dm); 00184 // accumulate P^TPxvol in pxvol_loc 00185 ierr = fwdpj3(volsize, nrays, nnz, dm, origin, radius, ptrs, 00186 cord, xvol_sph, projected_data); 00187 ierr = bckpj3(volsize, nrays, nnz, dm, origin, radius, ptrs, 00188 cord, projected_data, pxvol_loc); 00189 } 00190 } 00191 // and reduce the accumulated pxvol_loc's 00192 ierr = MPI_Allreduce(pxvol_loc, pxvol, nnz, MPI_FLOAT, MPI_SUM, comm); 00193 00194 for ( int j = 0 ; j < nnz ; ++j ) { 00195 grad[j] = bvol[j]; 00196 grad[j] -= pxvol[j]; 00197 } 00198 } 00199 rnorm = 0.0; 00200 for ( int j = 0 ; j < nnz ; ++j ) { 00201 rnorm += grad[j]* (double) grad[j]; 00202 } 00203 rnorm /= nnz; 00204 rnorm = sqrt(rnorm); 00205 if ( mypid == 0 ) printf("iter = %3d, rnorm / bnorm = %11.3e, rnorm = %11.3e\n", 00206 iter, rnorm / bnorm, rnorm); 00207 // if on the second pass, rnorm is greater than bnorm, 00208 // lam is probably set too high reduce it by a factor of 2 and start over 00209 // if ( iter == 2 && rnorm / bnorm > old_rnorm ) { 00210 if ( rnorm / bnorm > old_rnorm ) { 00211 // but don't do it more than 20 times 00212 if ( restarts > 20 ) { 00213 if ( mypid == 0 ) 00214 printf("Failure to converge, even with lam = %11.3e\n", lam); 00215 break; 00216 } else { 00217 ++restarts; 00218 iter = 1; 00219 lam /= 2.0; 00220 // reset these 00221 // kluge, 
make sure if its 1.0 + epsilon it still works 00222 old_rnorm = 1.0001; 00223 for ( int j = 0 ; j < nnz ; ++j ) { 00224 xvol_sph[j] = 0.0; 00225 pxvol_loc[j] = 0.0; 00226 } 00227 if ( mypid == 0 ) printf("reducing lam to %11.3e, restarting\n", lam); 00228 continue; 00229 } 00230 } 00231 // if changes are sufficiently small, or if no further progress is made, terminate 00232 if ( rnorm / bnorm < tol || rnorm / bnorm > old_rnorm ) { 00233 if ( mypid == 0 ) 00234 printf("Terminating with rnorm/bnorm = %11.3e, tol = %11.3e, "); 00235 printf("old_rnorm = %11.3e\n", rnorm/bnorm, tol, old_rnorm); 00236 break; 00237 } 00238 // update the termination threshold 00239 old_rnorm = rnorm / bnorm; 00240 // update the reconstructed volume 00241 for ( int j = 0 ; j < nnz ; ++j ) { 00242 xvol_sph[j] += lam * grad[j]; 00243 // reset it so it's ready to accumulate for the next iteration 00244 pxvol_loc[j] = 0.0; 00245 } 00246 00247 ++iter; 00248 } 00249 if (mypid == 0) printf("Total time in SIRT = %11.3e\n", MPI_Wtime()-t0); 00250 00251 // unpack the spherical volume back out into the original EMData object 00252 ierr = sph2cb(xvol_sph, volsize, nrays, radius, nnz, ptrs, cord, voldata); 00253 00254 EMDeleteArray(grad); 00255 EMDeleteArray(pxvol); 00256 EMDeleteArray(bvol); 00257 EMDeleteArray(pxvol_loc); 00258 EMDeleteArray(bvol_loc); 00259 00260 EMDeleteArray(ptrs); 00261 EMDeleteArray(cord); 00262 EMDeleteArray(xvol_sph); 00263 EMDeleteArray(projected_data); 00264 00265 EMDeleteArray(psize); 00266 EMDeleteArray(nbase); 00267 00268 return 0; // recons3d_sirt_mpi 00269 }
|