00001
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036 #include "cmp.h"
00037 #include "emdata.h"
00038 #include "ctf.h"
00039 #include "plugins/cmp_template.h"
00040
00041 using namespace EMAN;
00042
00043 const string CccCmp::NAME = "ccc";
00044 const string SqEuclideanCmp::NAME = "sqeuclidean";
00045 const string DotCmp::NAME = "dot";
00046 const string TomoDotCmp::NAME = "dot.tomo";
00047 const string QuadMinDotCmp::NAME = "quadmindot";
00048 const string OptVarianceCmp::NAME = "optvariance";
00049 const string PhaseCmp::NAME = "phase";
00050 const string FRCCmp::NAME = "frc";
00051
00052 template <> Factory < Cmp >::Factory()
00053 {
00054 force_add<CccCmp>();
00055 force_add<SqEuclideanCmp>();
00056 force_add<DotCmp>();
00057 force_add<TomoDotCmp>();
00058 force_add<QuadMinDotCmp>();
00059 force_add<OptVarianceCmp>();
00060 force_add<PhaseCmp>();
00061 force_add<FRCCmp>();
00062
00063 }
00064
00065 void Cmp::validate_input_args(const EMData * image, const EMData *with) const
00066 {
00067 if (!image) {
00068 throw NullPointerException("compared image");
00069 }
00070 if (!with) {
00071 throw NullPointerException("compare-with image");
00072 }
00073
00074 if (!EMUtil::is_same_size(image, with)) {
00075 throw ImageFormatException( "images not same size");
00076 }
00077
00078 float *d1 = image->get_data();
00079 if (!d1) {
00080 throw NullPointerException("image contains no data");
00081 }
00082
00083 float *d2 = with->get_data();
00084 if (!d2) {
00085 throw NullPointerException("compare-with image data");
00086 }
00087 }
00088
00089
00090 float CccCmp::cmp(EMData * image, EMData *with) const
00091 {
00092 ENTERFUNC;
00093 if (image->is_complex() || with->is_complex())
00094 throw ImageFormatException( "Complex images not supported by CMP::CccCmp");
00095 validate_input_args(image, with);
00096
00097 const float *const d1 = image->get_const_data();
00098 const float *const d2 = with->get_const_data();
00099
00100 float negative = (float)params.set_default("negative", 1);
00101 if (negative) negative=-1.0; else negative=1.0;
00102
00103 double avg1 = 0.0, var1 = 0.0, avg2 = 0.0, var2 = 0.0, ccc = 0.0;
00104 long n = 0;
00105 size_t totsize = image->get_xsize()*image->get_ysize()*image->get_zsize();
00106
00107 bool has_mask = false;
00108 EMData* mask = 0;
00109 if (params.has_key("mask")) {
00110 mask = params["mask"];
00111 if(mask!=0) {has_mask=true;}
00112 }
00113
00114 if (has_mask) {
00115 const float *const dm = mask->get_const_data();
00116 for (size_t i = 0; i < totsize; ++i) {
00117 if (dm[i] > 0.5) {
00118 avg1 += double(d1[i]);
00119 var1 += d1[i]*double(d1[i]);
00120 avg2 += double(d2[i]);
00121 var2 += d2[i]*double(d2[i]);
00122 ccc += d1[i]*double(d2[i]);
00123 n++;
00124 }
00125 }
00126 } else {
00127 for (size_t i = 0; i < totsize; ++i) {
00128 avg1 += double(d1[i]);
00129 var1 += d1[i]*double(d1[i]);
00130 avg2 += double(d2[i]);
00131 var2 += d2[i]*double(d2[i]);
00132 ccc += d1[i]*double(d2[i]);
00133 }
00134 n = totsize;
00135 }
00136
00137 avg1 /= double(n);
00138 var1 = var1/double(n) - avg1*avg1;
00139 avg2 /= double(n);
00140 var2 = var2/double(n) - avg2*avg2;
00141 ccc = ccc/double(n) - avg1*avg2;
00142 ccc /= sqrt(var1*var2);
00143 ccc *= negative;
00144 return static_cast<float>(ccc);
00145 EXITFUNC;
00146 }
00147
00148
00149
00150
00151 float SqEuclideanCmp::cmp(EMData *image,EMData * withorig ) const
00152 {
00153 ENTERFUNC;
00154 EMData *with = withorig;
00155 validate_input_args(image, with);
00156
00157 int zeromask = params.set_default("zeromask",0);
00158 int normto = params.set_default("normto",0);
00159
00160 if (normto) {
00161 with = withorig->process("normalize.toimage",Dict("to",image));
00162 with->set_attr("deleteme",1);
00163 if ((float)(with->get_attr("norm_mult"))<=0) {
00164 delete with;
00165 with=withorig;
00166 }
00167 }
00168
00169 const float *const y_data = with->get_const_data();
00170 const float *const x_data = image->get_const_data();
00171 double result = 0.;
00172 float n = 0.0f;
00173 if(image->is_complex() && with->is_complex()) {
00174
00175 int nx = with->get_xsize();
00176 int ny = with->get_ysize();
00177 int nz = with->get_zsize();
00178 nx = (nx - 2 + with->is_fftodd());
00179 int lsd2 = (nx + 2 - nx%2) ;
00180
00181 int ixb = 2*((nx+1)%2);
00182 int iyb = ny%2;
00183
00184 if(nz == 1) {
00185
00186 for ( int iz = 0; iz <= nz-1; iz++) {
00187 double part = 0.;
00188 for ( int iy = 0; iy <= ny-1; iy++) {
00189 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ix++) {
00190 size_t ii = ix + (iy + iz * ny)* lsd2;
00191 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00192 }
00193 }
00194 for ( int iy = 1; iy <= ny/2-1 + iyb; iy++) {
00195 size_t ii = (iy + iz * ny)* lsd2;
00196 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00197 part += (x_data[ii+1] - y_data[ii+1])*double(x_data[ii+1] - y_data[ii+1]);
00198 }
00199 if(nx%2 == 0) {
00200 for ( int iy = 1; iy <= ny/2-1 + iyb; iy++) {
00201 size_t ii = lsd2 - 2 + (iy + iz * ny)* lsd2;
00202 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00203 part += (x_data[ii+1] - y_data[ii+1])*double(x_data[ii+1] - y_data[ii+1]);
00204 }
00205
00206 }
00207 part *= 2;
00208 part += (x_data[0] - y_data[0])*double(x_data[0] - y_data[0]);
00209 if(ny%2 == 0) {
00210 int ii = (ny/2 + iz * ny)* lsd2;
00211 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00212 }
00213 if(nx%2 == 0) {
00214 int ii = lsd2 - 2 + (0 + iz * ny)* lsd2;
00215 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00216 if(ny%2 == 0) {
00217 int ii = lsd2 - 2 +(ny/2 + iz * ny)* lsd2;
00218 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00219 }
00220 }
00221 result += part;
00222 }
00223 n = (float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz;
00224
00225 }else{
00226 int ky, kz;
00227 int ny2 = ny/2; int nz2 = nz/2;
00228 for ( int iz = 0; iz <= nz-1; iz++) {
00229 if(iz>nz2) kz=iz-nz; else kz=iz;
00230 for ( int iy = 0; iy <= ny-1; iy++) {
00231 if(iy>ny2) ky=iy-ny; else ky=iy;
00232 for ( int ix = 0; ix <= lsd2-1; ix++) {
00233
00234 if(ix>0 || (kz>=0 && (ky>=0 || kz!=0))) {
00235 size_t ii = ix + (iy + iz * ny)* lsd2;
00236 result += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00237 }
00238 }
00239 }
00240 }
00241 n = ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz)/2.0f;
00242 }
00243 } else {
00244 size_t totsize = image->get_xsize()*image->get_ysize()*image->get_zsize();
00245 if (params.has_key("mask")) {
00246 EMData* mask;
00247 mask = params["mask"];
00248 const float *const dm = mask->get_const_data();
00249 for (size_t i = 0; i < totsize; i++) {
00250 if (dm[i] > 0.5) {
00251 double temp = x_data[i]- y_data[i];
00252 result += temp*temp;
00253 n++;
00254 }
00255 }
00256 }
00257 else if (zeromask) {
00258 n=0;
00259 for (size_t i = 0; i < totsize; i++) {
00260 if (x_data[i]==0 || y_data[i]==0) continue;
00261 double temp = x_data[i]- y_data[i];
00262 result += temp*temp;
00263 n++;
00264 }
00265
00266 }
00267 else {
00268 for (size_t i = 0; i < totsize; i++) {
00269 double temp = x_data[i]- y_data[i];
00270 result += temp*temp;
00271 }
00272 n = (float)totsize;
00273 }
00274 }
00275 result/=n;
00276
00277 EXITFUNC;
00278 if (with->has_attr("deleteme")) delete with;
00279 return static_cast<float>(result);
00280 }
00281
00282
00283
00284
00285 float DotCmp::cmp(EMData* image, EMData* with) const
00286 {
00287 ENTERFUNC;
00288 validate_input_args(image, with);
00289
00290 const float *const x_data = image->get_const_data();
00291 const float *const y_data = with->get_const_data();
00292
00293 int normalize = params.set_default("normalize", 0);
00294 float negative = (float)params.set_default("negative", 1);
00295
00296 if (negative) negative=-1.0; else negative=1.0;
00297 double result = 0.;
00298 long n = 0;
00299 if(image->is_complex() && with->is_complex()) {
00300
00301 int nx = with->get_xsize();
00302 int ny = with->get_ysize();
00303 int nz = with->get_zsize();
00304 nx = (nx - 2 + with->is_fftodd());
00305 int lsd2 = (nx + 2 - nx%2) ;
00306
00307 int ixb = 2*((nx+1)%2);
00308 int iyb = ny%2;
00309
00310 if(nz == 1) {
00311
00312 for ( int iz = 0; iz <= nz-1; ++iz) {
00313 double part = 0.;
00314 for ( int iy = 0; iy <= ny-1; ++iy) {
00315 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ++ix) {
00316 size_t ii = ix + (iy + iz * ny)* lsd2;
00317 part += x_data[ii] * double(y_data[ii]);
00318 }
00319 }
00320 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00321 size_t ii = (iy + iz * ny)* lsd2;
00322 part += x_data[ii] * double(y_data[ii]);
00323 part += x_data[ii+1] * double(y_data[ii+1]);
00324 }
00325 if(nx%2 == 0) {
00326 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00327 size_t ii = lsd2 - 2 + (iy + iz * ny)* lsd2;
00328 part += x_data[ii] * double(y_data[ii]);
00329 part += x_data[ii+1] * double(y_data[ii+1]);
00330 }
00331
00332 }
00333 part *= 2;
00334 part += x_data[0] * double(y_data[0]);
00335 if(ny%2 == 0) {
00336 size_t ii = (ny/2 + iz * ny)* lsd2;
00337 part += x_data[ii] * double(y_data[ii]);
00338 }
00339 if(nx%2 == 0) {
00340 size_t ii = lsd2 - 2 + (0 + iz * ny)* lsd2;
00341 part += x_data[ii] * double(y_data[ii]);
00342 if(ny%2 == 0) {
00343 int ii = lsd2 - 2 +(ny/2 + iz * ny)* lsd2;
00344 part += x_data[ii] * double(y_data[ii]);
00345 }
00346 }
00347 result += part;
00348 }
00349 if( normalize ) {
00350
00351 double square_sum1 = 0., square_sum2 = 0.;
00352 for ( int iz = 0; iz <= nz-1; ++iz) {
00353 for ( int iy = 0; iy <= ny-1; ++iy) {
00354 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ++ix) {
00355 size_t ii = ix + (iy + iz * ny)* lsd2;
00356 square_sum1 += x_data[ii] * double(x_data[ii]);
00357 square_sum2 += y_data[ii] * double(y_data[ii]);
00358 }
00359 }
00360 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00361 size_t ii = (iy + iz * ny)* lsd2;
00362 square_sum1 += x_data[ii] * double(x_data[ii]);
00363 square_sum1 += x_data[ii+1] * double(x_data[ii+1]);
00364 square_sum2 += y_data[ii] * double(y_data[ii]);
00365 square_sum2 += y_data[ii+1] * double(y_data[ii+1]);
00366 }
00367 if(nx%2 == 0) {
00368 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00369 size_t ii = lsd2 - 2 + (iy + iz * ny)* lsd2;
00370 square_sum1 += x_data[ii] * double(x_data[ii]);
00371 square_sum1 += x_data[ii+1] * double(x_data[ii+1]);
00372 square_sum2 += y_data[ii] * double(y_data[ii]);
00373 square_sum2 += y_data[ii+1] * double(y_data[ii+1]);
00374 }
00375
00376 }
00377 square_sum1 *= 2;
00378 square_sum1 += x_data[0] * double(x_data[0]);
00379 square_sum2 *= 2;
00380 square_sum2 += y_data[0] * double(y_data[0]);
00381 if(ny%2 == 0) {
00382 int ii = (ny/2 + iz * ny)* lsd2;
00383 square_sum1 += x_data[ii] * double(x_data[ii]);
00384 square_sum2 += y_data[ii] * double(y_data[ii]);
00385 }
00386 if(nx%2 == 0) {
00387 int ii = lsd2 - 2 + (0 + iz * ny)* lsd2;
00388 square_sum1 += x_data[ii] * double(x_data[ii]);
00389 square_sum2 += y_data[ii] * double(y_data[ii]);
00390 if(ny%2 == 0) {
00391 int ii = lsd2 - 2 +(ny/2 + iz * ny)* lsd2;
00392 square_sum1 += x_data[ii] * double(x_data[ii]);
00393 square_sum2 += y_data[ii] * double(y_data[ii]);
00394 }
00395 }
00396 }
00397 result /= sqrt(square_sum1*square_sum2);
00398 } else result /= ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz);
00399
00400 } else {
00401 int ky, kz;
00402 int ny2 = ny/2; int nz2 = nz/2;
00403 for ( int iz = 0; iz <= nz-1; ++iz) {
00404 if(iz>nz2) kz=iz-nz; else kz=iz;
00405 for ( int iy = 0; iy <= ny-1; ++iy) {
00406 if(iy>ny2) ky=iy-ny; else ky=iy;
00407 for ( int ix = 0; ix <= lsd2-1; ++ix) {
00408
00409 if(ix>0 || (kz>=0 && (ky>=0 || kz!=0))) {
00410 size_t ii = ix + (iy + iz * ny)* lsd2;
00411 result += x_data[ii] * double(y_data[ii]);
00412 }
00413 }
00414 }
00415 }
00416 if( normalize ) {
00417
00418 double square_sum1 = 0., square_sum2 = 0.;
00419 int ky, kz;
00420 int ny2 = ny/2; int nz2 = nz/2;
00421 for ( int iz = 0; iz <= nz-1; ++iz) {
00422 if(iz>nz2) kz=iz-nz; else kz=iz;
00423 for ( int iy = 0; iy <= ny-1; ++iy) {
00424 if(iy>ny2) ky=iy-ny; else ky=iy;
00425 for ( int ix = 0; ix <= lsd2-1; ++ix) {
00426
00427 if(ix>0 || (kz>=0 && (ky>=0 || kz!=0))) {
00428 size_t ii = ix + (iy + iz * ny)* lsd2;
00429 square_sum1 += x_data[ii] * double(x_data[ii]);
00430 square_sum2 += y_data[ii] * double(y_data[ii]);
00431 }
00432 }
00433 }
00434 }
00435 result /= sqrt(square_sum1*square_sum2);
00436 } else result /= ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz/2);
00437 }
00438 } else {
00439 size_t totsize = image->get_xsize() * image->get_ysize() * image->get_zsize();
00440
00441 double square_sum1 = 0., square_sum2 = 0.;
00442
00443 if (params.has_key("mask")) {
00444 EMData* mask;
00445 mask = params["mask"];
00446 const float *const dm = mask->get_const_data();
00447 if (normalize) {
00448 for (size_t i = 0; i < totsize; i++) {
00449 if (dm[i] > 0.5) {
00450 square_sum1 += x_data[i]*double(x_data[i]);
00451 square_sum2 += y_data[i]*double(y_data[i]);
00452 result += x_data[i]*double(y_data[i]);
00453 }
00454 }
00455 } else {
00456 for (size_t i = 0; i < totsize; i++) {
00457 if (dm[i] > 0.5) {
00458 result += x_data[i]*double(y_data[i]);
00459 n++;
00460 }
00461 }
00462 }
00463 } else {
00464
00465 for (size_t i=0; i<totsize; i++) result+=x_data[i]*y_data[i];
00466
00467 if (normalize) {
00468 square_sum1 = image->get_attr("square_sum");
00469 square_sum2 = with->get_attr("square_sum");
00470 } else n = totsize;
00471 }
00472 if (normalize) result /= (sqrt(square_sum1*square_sum2)); else result /= n;
00473 }
00474
00475
00476 EXITFUNC;
00477 return (float) (negative*result);
00478 }
00479
00480
00481 float TomoDotCmp::cmp(EMData * image, EMData *with) const
00482 {
00483 ENTERFUNC;
00484 float threshold = params.set_default("threshold",0.f);
00485 if (threshold < 0.0f) throw InvalidParameterException("The threshold parameter must be greater than or equal to zero");
00486
00487 if ( threshold > 0) {
00488 EMData* ccf = params.set_default("ccf",(EMData*) NULL);
00489 bool ccf_ownership = false;
00490 if (!ccf) {
00491 ccf = image->calc_ccf(with);
00492 ccf_ownership = true;
00493 }
00494 bool norm = params.set_default("norm",false);
00495 if (norm) ccf->process_inplace("normalize");
00496 int tx = params.set_default("tx",0); int ty = params.set_default("ty",0); int tz = params.set_default("tz",0);
00497 float best_score = ccf->get_value_at_wrap(tx,ty,tz)/static_cast<float>(image->get_size());
00498 EMData* ccf_fft = ccf->do_fft();
00499 if (ccf_ownership) delete ccf; ccf = 0;
00500 ccf_fft->process_inplace("threshold.binary.fourier",Dict("value",threshold));
00501 float map_sum = ccf_fft->get_attr("mean");
00502 if (map_sum == 0.0f) throw UnexpectedBehaviorException("The number of voxels in the Fourier image with an amplitude above your threshold is zero. Please adjust your parameters");
00503 best_score /= map_sum;
00504 delete ccf_fft; ccf_fft = 0;
00505 return -best_score;
00506 } else {
00507 return -image->dot(with);
00508 }
00509
00510
00511 }
00512
00513
00514
00515 float QuadMinDotCmp::cmp(EMData * image, EMData *with) const
00516 {
00517 ENTERFUNC;
00518 validate_input_args(image, with);
00519
00520 if (image->get_zsize()!=1) throw InvalidValueException(0, "QuadMinDotCmp supports 2D only");
00521
00522 int nx=image->get_xsize();
00523 int ny=image->get_ysize();
00524
00525 int normalize = params.set_default("normalize", 0);
00526 float negative = (float)params.set_default("negative", 1);
00527
00528 if (negative) negative=-1.0; else negative=1.0;
00529
00530 double result[4] = { 0,0,0,0 }, sq1[4] = { 0,0,0,0 }, sq2[4] = { 0,0,0,0 } ;
00531
00532 vector<int> image_saved_offsets = image->get_array_offsets();
00533 vector<int> with_saved_offsets = with->get_array_offsets();
00534 image->set_array_offsets(-nx/2,-ny/2);
00535 with->set_array_offsets(-nx/2,-ny/2);
00536 int i,x,y;
00537 for (y=-ny/2; y<ny/2; y++) {
00538 for (x=-nx/2; x<nx/2; x++) {
00539 int quad=(x<0?0:1) + (y<0?0:2);
00540 result[quad]+=(*image)(x,y)*(*with)(x,y);
00541 if (normalize) {
00542 sq1[quad]+=(*image)(x,y)*(*image)(x,y);
00543 sq2[quad]+=(*with)(x,y)*(*with)(x,y);
00544 }
00545 }
00546 }
00547 image->set_array_offsets(image_saved_offsets);
00548 with->set_array_offsets(with_saved_offsets);
00549
00550 if (normalize) {
00551 for (i=0; i<4; i++) result[i]/=sqrt(sq1[i]*sq2[i]);
00552 } else {
00553 for (i=0; i<4; i++) result[i]/=nx*ny/4;
00554 }
00555
00556 float worst=static_cast<float>(result[0]);
00557 for (i=1; i<4; i++) if (static_cast<float>(result[i])<worst) worst=static_cast<float>(result[i]);
00558
00559 EXITFUNC;
00560 return (float) (negative*worst);
00561 }
00562
00563 float OptVarianceCmp::cmp(EMData * image, EMData *with) const
00564 {
00565 ENTERFUNC;
00566 validate_input_args(image, with);
00567
00568 int keepzero = params.set_default("keepzero", 1);
00569 int invert = params.set_default("invert",0);
00570 int matchfilt = params.set_default("matchfilt",1);
00571 int matchamp = params.set_default("matchamp",0);
00572 int radweight = params.set_default("radweight",0);
00573 int dbug = params.set_default("debug",0);
00574
00575 size_t size = image->get_xsize() * image->get_ysize() * image->get_zsize();
00576
00577
00578 EMData *with2=NULL;
00579 if (matchfilt) {
00580 EMData *a = image->do_fft();
00581 EMData *b = with->do_fft();
00582
00583 vector <float> rfa=a->calc_radial_dist(a->get_ysize()/2,0.0f,1.0f,1);
00584 vector <float> rfb=b->calc_radial_dist(b->get_ysize()/2,0.0f,1.0f,1);
00585
00586 float avg=0;
00587 for (size_t i=0; i<a->get_ysize()/2.0f; i++) {
00588 rfa[i]=(rfb[i]==0?0.0f:(rfa[i]/rfb[i]));
00589 avg+=rfa[i];
00590 }
00591
00592 avg/=a->get_ysize()/2.0f;
00593 for (size_t i=0; i<a->get_ysize()/2.0f; i++) {
00594 if (rfa[i]>avg*10.0) rfa[i]=10.0;
00595 }
00596 rfa[0]=0.0;
00597
00598 if (dbug) b->write_image("a.hdf",-1);
00599
00600 b->apply_radial_func(0.0f,1.0f/a->get_ysize(),rfa);
00601 with2=b->do_ift();
00602
00603 if (dbug) b->write_image("a.hdf",-1);
00604 if (dbug) a->write_image("a.hdf",-1);
00605
00606
00607
00608
00609
00610
00611
00612
00613
00614
00615
00616
00617 delete a;
00618 delete b;
00619
00620 if (dbug) {
00621 with2->write_image("a.hdf",-1);
00622 image->write_image("a.hdf",-1);
00623 }
00624
00625
00626
00627 }
00628
00629
00630
00631 if (matchamp) {
00632 EMData *a = image->do_fft();
00633 EMData *b = with->do_fft();
00634 size_t size2 = a->get_xsize() * a->get_ysize() * a->get_zsize();
00635
00636 a->ri2ap();
00637 b->ri2ap();
00638
00639 const float *const ad=a->get_const_data();
00640 float * bd=b->get_data();
00641
00642 for (size_t i=0; i<size2; i+=2) bd[i]=ad[i];
00643 b->update();
00644
00645 b->ap2ri();
00646 with2=b->do_ift();
00647
00648 delete a;
00649 delete b;
00650 }
00651
00652 const float * x_data;
00653 if (with2) x_data=with2->get_const_data();
00654 else x_data = with->get_const_data();
00655 const float *const y_data = image->get_const_data();
00656
00657 size_t nx = image->get_xsize();
00658 float m = 0;
00659 float b = 0;
00660
00661
00662
00663 if (dbug) {
00664 FILE *out=fopen("dbug.optvar.txt","w");
00665 if (out) {
00666 for (size_t i=0; i<size; i++) {
00667 if ( !keepzero || (x_data[i] && y_data[i])) fprintf(out,"%g\t%g\n",x_data[i],y_data[i]);
00668 }
00669 fclose(out);
00670 }
00671 }
00672
00673
00674 Util::calc_least_square_fit(size, x_data, y_data, &m, &b, keepzero);
00675 if (m == 0) {
00676 m = FLT_MIN;
00677 }
00678 b = -b / m;
00679 m = 1.0f / m;
00680
00681
00682
00683
00684
00685
00686
00687
00688 double result = 0;
00689 int count = 0;
00690
00691 if (radweight) {
00692 if (image->get_zsize()!=1) throw ImageDimensionException("radweight option is 2D only");
00693 if (keepzero) {
00694 for (size_t i = 0,y=0; i < size; y++) {
00695 for (size_t x=0; x<nx; i++,x++) {
00696 if (y_data[i] && x_data[i]) {
00697 #ifdef _WIN32
00698 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m)*(_hypot((float)x,(float)y)+nx/4.0);
00699 else result += Util::square((x_data[i] * m) + b - y_data[i])*(_hypot((float)x,(float)y)+nx/4.0);
00700 #else
00701 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m)*(hypot((float)x,(float)y)+nx/4.0);
00702 else result += Util::square((x_data[i] * m) + b - y_data[i])*(hypot((float)x,(float)y)+nx/4.0);
00703 #endif
00704 count++;
00705 }
00706 }
00707 }
00708 result/=count;
00709 }
00710 else {
00711 for (size_t i = 0,y=0; i < size; y++) {
00712 for (size_t x=0; x<nx; i++,x++) {
00713 #ifdef _WIN32
00714 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m)*(_hypot((float)x,(float)y)+nx/4.0);
00715 else result += Util::square((x_data[i] * m) + b - y_data[i])*(_hypot((float)x,(float)y)+nx/4.0);
00716 #else
00717 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m)*(hypot((float)x,(float)y)+nx/4.0);
00718 else result += Util::square((x_data[i] * m) + b - y_data[i])*(hypot((float)x,(float)y)+nx/4.0);
00719 #endif
00720 }
00721 }
00722 result = result / size;
00723 }
00724 }
00725 else {
00726 if (keepzero) {
00727 for (size_t i = 0; i < size; i++) {
00728 if (y_data[i] && x_data[i]) {
00729 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m);
00730 else result += Util::square((x_data[i] * m) + b - y_data[i]);
00731 count++;
00732 }
00733 }
00734 result/=count;
00735 }
00736 else {
00737 for (size_t i = 0; i < size; i++) {
00738 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m);
00739 else result += Util::square((x_data[i] * m) + b - y_data[i]);
00740 }
00741 result = result / size;
00742 }
00743 }
00744 scale = m;
00745 shift = b;
00746
00747 image->set_attr("ovcmp_m",m);
00748 image->set_attr("ovcmp_b",b);
00749 if (with2) delete with2;
00750 EXITFUNC;
00751
00752 #if 0
00753 return (1 - result);
00754 #endif
00755
00756 return static_cast<float>(result);
00757 }
00758
00759 float PhaseCmp::cmp(EMData * image, EMData *with) const
00760 {
00761 ENTERFUNC;
00762
00763 int snrweight = params.set_default("snrweight", 0);
00764 int snrfn = params.set_default("snrfn",0);
00765 int ampweight = params.set_default("ampweight",0);
00766 int zeromask = params.set_default("zeromask",0);
00767 float minres = params.set_default("minres",500.0f);
00768 float maxres = params.set_default("maxres",10.0f);
00769
00770 if (snrweight && snrfn) throw InvalidCallException("SNR weight and SNRfn cannot both be set in the phase comparator");
00771
00772 #ifdef EMAN2_USING_CUDA
00773 if (image->gpu_operation_preferred()) {
00774
00775 EXITFUNC;
00776 return cuda_cmp(image,with);
00777 }
00778 #endif
00779
00780 EMData *image_fft = NULL;
00781 EMData *with_fft = NULL;
00782
00783 int ny = image->get_ysize();
00784
00785 int np = 0;
00786 vector<float> snr;
00787
00788
00789 if (snrweight) {
00790 Ctf *ctf = NULL;
00791 if (!image->has_attr("ctf")) {
00792 if (!with->has_attr("ctf")) throw InvalidCallException("SNR weight with no CTF parameters");
00793 ctf=with->get_attr("ctf");
00794 }
00795 else ctf=image->get_attr("ctf");
00796
00797 float ds=1.0f/(ctf->apix*ny);
00798 snr=ctf->compute_1d(ny,ds,Ctf::CTF_SNR);
00799 if(ctf) {delete ctf; ctf=0;}
00800 np=snr.size();
00801 }
00802
00803 else if (snrfn==1) {
00804 np=ny/2;
00805 snr.resize(np);
00806
00807 for (int i = 0; i < np; i++) {
00808 float x2 = 10.0f*i/np;
00809 snr[i] = x2 * exp(-x2);
00810 }
00811 }
00812
00813 else {
00814 np=ny/2;
00815 snr.resize(np);
00816
00817 for (int i = 0; i < np; i++) snr[i]=1.0;
00818 }
00819
00820
00821 float pmin,pmax;
00822 if (minres>0) pmin=((float)image->get_attr("apix_x")*image->get_ysize())/minres;
00823 else pmin=0;
00824 if (maxres>0) pmax=((float)image->get_attr("apix_x")*image->get_ysize())/maxres;
00825 else pmax=0;
00826
00827
00828
00829
00830 for (int i = 0; i < np; i++) {
00831 if (pmin>0) snr[i]*=(tanh(5.0f*(i-pmin)/pmin)+1.0f)/2.0f;
00832 if (pmax>0) snr[i]*=(1.0f-tanh(i-pmax))/2.0f;
00833
00834 }
00835
00836 if (zeromask) {
00837 image_fft=image->copy();
00838 with_fft=with->copy();
00839
00840 if (image_fft->is_complex()) image_fft->do_ift_inplace();
00841 if (with_fft->is_complex()) with_fft->do_ift_inplace();
00842
00843 int sz=image_fft->get_xsize()*image_fft->get_ysize()*image_fft->get_zsize();
00844 float *d1=image_fft->get_data();
00845 float *d2=with_fft->get_data();
00846
00847 for (int i=0; i<sz; i++) {
00848 if (d1[i]==0.0 || d2[i]==0.0) { d1[i]=0.0; d2[i]=0.0; }
00849 }
00850
00851 image_fft->update();
00852 with_fft->update();
00853 image_fft->do_fft_inplace();
00854 with_fft->do_fft_inplace();
00855 image_fft->set_attr("free_me",1);
00856 with_fft->set_attr("free_me",1);
00857 }
00858 else {
00859 if (image->is_complex()) image_fft=image;
00860 else {
00861 image_fft=image->do_fft();
00862 image_fft->set_attr("free_me",1);
00863 }
00864
00865 if (with->is_complex()) with_fft=with;
00866 else {
00867 with_fft=with->do_fft();
00868 with_fft->set_attr("free_me",1);
00869 }
00870 }
00871
00872
00873
00874 const float *const image_fft_data = image_fft->get_const_data();
00875 const float *const with_fft_data = with_fft->get_const_data();
00876 double sum = 0;
00877 double norm = FLT_MIN;
00878 size_t i = 0;
00879 int nx=image_fft->get_xsize();
00880 ny=image_fft->get_ysize();
00881 int nz=image_fft->get_zsize();
00882 int nx2=image_fft->get_xsize()/2;
00883 int ny2=image_fft->get_ysize()/2;
00884 int nz2=image_fft->get_zsize()/2;
00885
00886
00887 if (np==0) {
00888 for (int z = 0; z < nz; z++){
00889 for (int y = 0; y < ny; y++) {
00890 for (int x = 0; x < nx2; x ++) {
00891 float a;
00892 if (ampweight) a= (float)hypot(with_fft_data[i],with_fft_data[i+1]);
00893 else a=1.0f;
00894 sum += Util::angle_err_ri(image_fft_data[i],image_fft_data[i+1],with_fft_data[i],with_fft_data[i+1]) * a;
00895 norm += a;
00896 i += 2;
00897 }
00898 }
00899 }
00900
00901 }
00902 else if (nz==1) {
00903 for (int y = 0; y < ny; y++) {
00904 for (int x = 0; x < nx2; x ++) {
00905 int r=Util::hypot_fast_int(x,y>ny/2?ny-y:y);
00906 if (r>=ny2) { i+=2; continue; }
00907
00908 float a;
00909 if (ampweight) a= (float)hypot(with_fft_data[i],with_fft_data[i+1]);
00910 else a=1.0f;
00911 a*=snr[r];
00912 sum += Util::angle_err_ri(image_fft_data[i],image_fft_data[i+1],with_fft_data[i],with_fft_data[i+1]) * a;
00913 norm += a;
00914 i += 2;
00915 }
00916 }
00917 }
00918 else {
00919 for (int z = 0; z < nz; z++){
00920 for (int y = 0; y < ny; y++) {
00921 for (int x = 0; x < nx2; x ++) {
00922 int r=(int)Util::hypot3(x,y>ny/2?ny-y:y,z>nz/2?nz-z:z);
00923 if (r>=ny2) { i+=2; continue; }
00924
00925 float a;
00926 if (ampweight) a= (float)hypot(with_fft_data[i],with_fft_data[i+1]);
00927 else a=1.0f;
00928 a*=snr[r];
00929 sum += Util::angle_err_ri(image_fft_data[i],image_fft_data[i+1],with_fft_data[i],with_fft_data[i+1]) * a;
00930 norm += a;
00931 i += 2;
00932 }
00933 }
00934 }
00935
00936 }
00937
00938 EXITFUNC;
00939
00940 if( image_fft->has_attr("free_me") )
00941 {
00942 delete image_fft;
00943 image_fft = 0;
00944 }
00945 if( with_fft->has_attr("free_me") )
00946 {
00947 delete with_fft;
00948 with_fft = 0;
00949 }
00950 #if 0
00951 return (1.0f - sum / norm);
00952 #endif
00953 return (float)(sum / norm);
00954 }
00955
#ifdef EMAN2_USING_CUDA
#include "cuda/cuda_cmp.h"
/** GPU implementation of the phase comparator.
 *
 *  Keeps static reduction pyramids and a phase-weighting image sized for the
 *  most recent input. These are rebuilt whenever the padded size changes.
 *  NOTE(review): the static state makes this function non-reentrant.
 *
 *  @throw UnexpectedBehaviorException for a 0-sized image or a failed device-to-host copy
 */
float PhaseCmp::cuda_cmp(EMData * image, EMData *with) const
{
	ENTERFUNC;
	validate_input_args(image, with);

	typedef vector<EMData*> EMDatas;
	static EMDatas hist_pyramid;
	static EMDatas norm_pyramid;
	static EMData weighting;
	static int image_size = 0;

	int size;
	EMData::CudaDataLock imagelock(image);
	EMData::CudaDataLock withlock(with);

	if (image->is_complex()) {
		size = image->get_xsize();
	} else {
		int nx = image->get_xsize()+2;
		nx -= nx%2;
		size = nx*image->get_ysize()*image->get_zsize();
	}
	if (size != image_size) {
		for(unsigned int i =0; i < hist_pyramid.size(); ++i) {
			delete hist_pyramid[i];
			delete norm_pyramid[i];
		}
		hist_pyramid.clear();
		norm_pyramid.clear();
		int s = size;
		if (s < 1) throw UnexpectedBehaviorException("The image is 0 size");
		// Round up to a power of two; the pyramid halves down to a single element.
		int p2 = 1;
		while ( s != 1 ) {
			s /= 2;
			p2 *= 2;
		}
		if ( p2 != size ) {
			p2 *= 2;
			s = p2;
		}
		if (s != 1) s /= 2;
		while (true) {
			EMData* h = new EMData();
			h->set_size_cuda(s); h->to_value(0.0);
			hist_pyramid.push_back(h);
			EMData* n = new EMData();
			n->set_size_cuda(s); n->to_value(0.0);
			norm_pyramid.push_back(n);
			if ( s == 1) break;
			s /= 2;
		}
		int nx = image->get_xsize()+2;
		nx -= nx%2;
		int ny = image->get_ysize();
		int nz = image->get_zsize();
		weighting.set_size_cuda(nx,ny,nz);

		weighting.set_size_cuda(nx/2,ny,nz);
		float np = (int) ceil(Ctf::CTFOS * sqrt(2.0f) * ny / 2) + 2;
		EMDataForCuda tmp = weighting.get_data_struct_for_cuda();
		calc_phase_weights_cuda(&tmp,np);

		image_size = size;
	}

	// std::vector instead of a variable-length array (a non-standard extension).
	vector<EMDataForCuda> hist(hist_pyramid.size());
	vector<EMDataForCuda> norm(hist_pyramid.size());

	EMDataForCuda wt = weighting.get_data_struct_for_cuda();
	EMData::CudaDataLock lock1(&weighting);
	for(unsigned int i = 0; i < hist_pyramid.size(); ++i ) {
		hist[i] = hist_pyramid[i]->get_data_struct_for_cuda();
		hist_pyramid[i]->cuda_lock();
		norm[i] = norm_pyramid[i]->get_data_struct_for_cuda();
		norm_pyramid[i]->cuda_lock();
	}

	float result = 0.0f;
	cudaError_t error = cudaSuccess;
	EMData *image_fft = 0;
	EMData *with_fft = 0;
	{
		// The locks reference the transforms, so they live in this inner scope
		// and are released before the transforms are deleted below.
		image_fft = image->do_fft_cuda();
		EMDataForCuda left = image_fft->get_data_struct_for_cuda();
		EMData::CudaDataLock lock2(image_fft);
		with_fft = with->do_fft_cuda();
		EMDataForCuda right = with_fft->get_data_struct_for_cuda();
		EMData::CudaDataLock lock3(with_fft);

		mean_phase_error_cuda(&left,&right,&wt,&hist[0],&norm[0],hist_pyramid.size());
		float* gpu_result = hist_pyramid[hist_pyramid.size()-1]->get_cuda_data();
		error = cudaMemcpy(&result,gpu_result,sizeof(float),cudaMemcpyDeviceToHost);
	}

	delete image_fft; image_fft=0;
	delete with_fft; with_fft=0;

	for(unsigned int i = 0; i < hist_pyramid.size(); ++i ) {
		hist_pyramid[i]->cuda_unlock();
		norm_pyramid[i]->cuda_unlock();
	}

	// Report the copy failure only after the pyramids have been unlocked.
	if ( error != cudaSuccess) throw UnexpectedBehaviorException( "CudaMemcpy (device to host) in the phase comparator failed:" + string(cudaGetErrorString(error)));

	EXITFUNC;
	return result;
}

#endif // EMAN2_USING_CUDA
01064
01065
01066 float FRCCmp::cmp(EMData * image, EMData * with) const
01067 {
01068 ENTERFUNC;
01069 validate_input_args(image, with);
01070
01071 int snrweight = params.set_default("snrweight", 0);
01072 int ampweight = params.set_default("ampweight", 0);
01073 int sweight = params.set_default("sweight", 1);
01074 int nweight = params.set_default("nweight", 0);
01075 int zeromask = params.set_default("zeromask",0);
01076 float minres = params.set_default("minres",500.0f);
01077 float maxres = params.set_default("maxres",10.0f);
01078
01079 if (zeromask) {
01080 image=image->copy();
01081 with=with->copy();
01082
01083 int sz=image->get_xsize()*image->get_ysize()*image->get_zsize();
01084 float *d1=image->get_data();
01085 float *d2=with->get_data();
01086
01087 for (int i=0; i<sz; i++) {
01088 if (d1[i]==0.0 || d2[i]==0.0) { d1[i]=0.0; d2[i]=0.0; }
01089 }
01090
01091 image->update();
01092 with->update();
01093 image->do_fft_inplace();
01094 with->do_fft_inplace();
01095 image->set_attr("free_me",1);
01096 with->set_attr("free_me",1);
01097 }
01098
01099
01100 if (!image->is_complex()) {
01101 image=image->do_fft();
01102 image->set_attr("free_me",1);
01103 }
01104 if (!with->is_complex()) {
01105 with=with->do_fft();
01106 with->set_attr("free_me",1);
01107 }
01108
01109 static vector < float >default_snr;
01110
01111
01112
01113
01114
01115
01116 int ny = image->get_ysize();
01117 int ny2=ny/2+1;
01118
01119 vector < float >fsc;
01120
01121
01122
01123 fsc = image->calc_fourier_shell_correlation(with,1);
01124
01125
01126
01127
01128
01129
01130
01131
01132
01133
01134
01135
01136
01137
01138
01139
01140
01141
01142
01143
01144
01145
01146
01147
01148
01149
01150
01151
01152
01153
01154
01155
01156
01157
01158 vector<float> snr;
01159 if (snrweight) {
01160 Ctf *ctf = NULL;
01161 if (!image->has_attr("ctf")) {
01162 if (!with->has_attr("ctf")) throw InvalidCallException("SNR weight with no CTF parameters");
01163 ctf=with->get_attr("ctf");
01164 }
01165 else ctf=image->get_attr("ctf");
01166
01167 float ds=1.0f/(ctf->apix*ny);
01168 snr=ctf->compute_1d(ny,ds,Ctf::CTF_SNR);
01169 if(ctf) {delete ctf; ctf=0;}
01170 }
01171
01172 vector<float> amp;
01173 if (ampweight) amp=image->calc_radial_dist(ny/2,0,1,0);
01174
01175
01176 float pmin,pmax;
01177 if (minres>0) pmin=((float)image->get_attr("apix_x")*image->get_ysize())/minres;
01178 else pmin=0;
01179 if (maxres>0) pmax=((float)image->get_attr("apix_x")*image->get_ysize())/maxres;
01180 else pmax=0;
01181
01182 double sum=0.0, norm=0.0;
01183
01184 for (int i=0; i<ny/2; i++) {
01185 double weight=1.0;
01186 if (sweight) weight*=fsc[(ny2)*2+i];
01187 if (ampweight) weight*=amp[i];
01188 if (snrweight) weight*=snr[i];
01189 if (pmin>0) weight*=(tanh(5.0*(i-pmin)/pmin)+1.0)/2.0;
01190 if (pmax>0) weight*=(1.0-tanh(i-pmax))/2.0;
01191
01192 sum+=weight*fsc[ny2+i];
01193 norm+=weight;
01194
01195 }
01196
01197
01198 sum/=norm;
01199 if (nweight && with->get_attr_default("ptcl_repr",0) && sum>=0 && sum<1.0) {
01200 sum=sum/(1.0-sum);
01201 sum/=(float)with->get_attr_default("ptcl_repr",0);
01202 sum=sum/(1.0+sum);
01203 }
01204
01205 if (image->has_attr("free_me")) delete image;
01206 if (with->has_attr("free_me")) delete with;
01207
01208 EXITFUNC;
01209
01210
01211
01212
01213
01214 return (float)-sum;
01215 }
01216
01217 void EMAN::dump_cmps()
01218 {
01219 dump_factory < Cmp > ();
01220 }
01221
01222 map<string, vector<string> > EMAN::dump_cmps_list()
01223 {
01224 return dump_factory_list < Cmp > ();
01225 }
01226
01227