EMAN::DotCmp Class Reference
This class is CUDA-enabled.

Uses the dot product of two same-size images to do the comparison.

#include <cmp.h>


Public Member Functions

float cmp (EMData *image, EMData *with) const
 Compares 'image' with the second image, 'with'.
string get_name () const
 Get the Cmp's name.
string get_desc () const
TypeDict get_param_types () const
 Get Cmp parameter information in a dictionary.

Static Public Member Functions

static Cmp * NEW ()

Static Public Attributes

static const string NAME = "dot"

Detailed Description

Uses the dot product of two same-size images to do the comparison.

Mask option added by PAP (04/23/06). For complex images, the comparison does not check whether the data are stored as real/imaginary or amplitude/phase pairs.

Author:
Steve Ludtke (sludtke@bcm.tmc.edu)
Date:
2005-07-13
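Parameters:
negative If set, returns -1 * dot product so that smaller is better; default true.
normalize If set, returns the normalized dot product (the cosine of the angle between the images), in the range -1.0 to 1.0; default false. When unset, the dot product of real images is divided by the number of pixels used.
mask Optional mask image; only pixels where the mask value exceeds 0.5 contribute. The mask is ignored when both images are complex.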

Definition at line 272 of file cmp.h.
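For two real images, cmp() reduces to a masked or unmasked dot product, followed by either normalization or division by the number of pixels used, with the sign flipped when 'negative' is set. The following is a minimal standalone sketch of that real-image path only, assuming the images are flat std::vector<float> buffers of equal length. DotOptions and dot_compare are hypothetical names, not part of EMAN2, and the CUDA and complex branches are not modeled.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical option set mirroring DotCmp's "negative", "normalize" and "mask" parameters.
struct DotOptions {
    bool negative = true;                      // DotCmp default: return -1 * result
    bool normalize = false;                    // DotCmp default: no normalization
    const std::vector<float>* mask = nullptr;  // optional; pixels with mask > 0.5 are used
};

// Sketch of the real-image branch, accumulating in double.
// No guard against an empty mask or zero-energy images (result would be NaN or inf).
double dot_compare(const std::vector<float>& x, const std::vector<float>& y,
                   const DotOptions& opt = DotOptions()) {
    assert(x.size() == y.size());                       // images must be the same size
    assert(!opt.mask || opt.mask->size() == x.size());  // mask must match as well

    double dot = 0.0, square_sum1 = 0.0, square_sum2 = 0.0;
    std::size_t n = 0;  // number of pixels that contributed
    for (std::size_t i = 0; i < x.size(); ++i) {
        if (opt.mask && (*opt.mask)[i] <= 0.5f) continue;  // pixel outside the mask
        dot += x[i] * double(y[i]);
        square_sum1 += x[i] * double(x[i]);
        square_sum2 += y[i] * double(y[i]);
        ++n;
    }

    // Normalized: cosine of the angle between the images, in [-1, 1].
    // Unnormalized: mean of the pixel-wise products.
    double result = opt.normalize ? dot / std::sqrt(square_sum1 * square_sum2)
                                  : dot / double(n);
    return opt.negative ? -result : result;
}
```

With x = {1, 2, 3} and y = {4, 5, 6}, the default options give -32/3: the dot product 32 is averaged over three pixels and then negated.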


Member Function Documentation

float DotCmp::cmp (EMData *image, EMData *with) const [virtual]

Compares 'image' with the second image, 'with'.

An optional transformation may be used to transform the two images. This is generic EMAN::Cmp behavior; DotCmp::cmp itself does not apply one.

Parameters:
image The first image to be compared.
with The second image to be compared.
Returns:
The comparison result. Smaller is better by default, because 'negative' defaults to true.

Implements EMAN::Cmp.

Definition at line 408 of file cmp.cpp.

References dm, dot_cmp_cuda(), ENTERFUNC, EXITFUNC, EMAN::EMData::get_attr(), EMAN::EMData::get_const_data(), EMAN::EMData::get_xsize(), EMAN::EMData::get_ysize(), EMAN::EMData::get_zsize(), EMAN::Dict::has_key(), EMAN::EMData::is_complex(), EMAN::EMData::is_fftodd(), nx, ny, EMAN::Cmp::params, EMAN::Dict::set_default(), sqrt(), and EMAN::Cmp::validate_input_args().

Referenced by EMAN::EMData::dot().

00409 {
00410         ENTERFUNC;
00411         
00412         validate_input_args(image, with);
00413 
00414         int normalize = params.set_default("normalize", 0);
00415         float negative = (float)params.set_default("negative", 1);
00416         if (negative) negative=-1.0; else negative=1.0;
00417 #ifdef EMAN2_USING_CUDA // So far this only works for real images. The CUDA path is tried first to avoid the non-CUDA overhead (calls to getdata are expensive).
00418         if(image->is_complex() && with->is_complex()) {
00419         } else {
00420                 if (image->getcudarwdata() && with->getcudarwdata()) {
00421                         //cout << "CUDA dot cmp" << endl;
00422                         float* maskdata = 0;
00423                         bool has_mask = false;
00424                         EMData* mask = 0;
00425                         if (params.has_key("mask")) {
00426                                 mask = params["mask"];
00427                                 if(mask!=0) {has_mask=true;}
00428                         }
00429                         if(has_mask) {
00430                                 if(!mask->getcudarwdata()) mask->copy_to_cuda();
00431                                 maskdata = mask->getcudarwdata(); // also use a GPU copy that already existed
00432                         }
00433 
00434                         float result = dot_cmp_cuda(image->getcudarwdata(), with->getcudarwdata(), maskdata, image->get_xsize(), image->get_ysize(), image->get_zsize());
00435                         result *= negative;
00436 
00437                         return result;
00438                         
00439                 }
00440         }
00441 #endif
00442         const float *const x_data = image->get_const_data();
00443         const float *const y_data = with->get_const_data();
00444 
00445         double result = 0.;
00446         long n = 0;
00447         if(image->is_complex() && with->is_complex()) {
00448         // Implemented by PAP  01/09/06 - please do not change.  If in doubt, write/call me.
00449                 int nx  = with->get_xsize();
00450                 int ny  = with->get_ysize();
00451                 int nz  = with->get_zsize();
00452                 nx = (nx - 2 + with->is_fftodd()); // nx is the real-space size of the input image
00453                 int lsd2 = (nx + 2 - nx%2) ; // Extended x-dimension of the complex image
00454 
00455                 int ixb = 2*((nx+1)%2);
00456                 int iyb = ny%2;
00457                 //
00458                 if(nz == 1) {
00459                 //  it looks like it could work in 3D, but does not
00460                 for ( int iz = 0; iz <= nz-1; ++iz) {
00461                         double part = 0.;
00462                         for ( int iy = 0; iy <= ny-1; ++iy) {
00463                                 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ++ix) {
00464                                         size_t ii = ix + (iy  + iz * ny)* lsd2;
00465                                         part += x_data[ii] * double(y_data[ii]);
00466                                 }
00467                         }
00468                         for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00469                                 size_t ii = (iy  + iz * ny)* lsd2;
00470                                 part += x_data[ii] * double(y_data[ii]);
00471                                 part += x_data[ii+1] * double(y_data[ii+1]);
00472                         }
00473                         if(nx%2 == 0) {
00474                                 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00475                                         size_t ii = lsd2 - 2 + (iy  + iz * ny)* lsd2;
00476                                         part += x_data[ii] * double(y_data[ii]);
00477                                         part += x_data[ii+1] * double(y_data[ii+1]);
00478                                 }
00479 
00480                         }
00481                         part *= 2;
00482                         part += x_data[0] * double(y_data[0]);
00483                         if(ny%2 == 0) {
00484                                 size_t ii = (ny/2  + iz * ny)* lsd2;
00485                                 part += x_data[ii] * double(y_data[ii]);
00486                         }
00487                         if(nx%2 == 0) {
00488                                 size_t ii = lsd2 - 2 + (0  + iz * ny)* lsd2;
00489                                 part += x_data[ii] * double(y_data[ii]);
00490                                 if(ny%2 == 0) {
00491                                         int ii = lsd2 - 2 +(ny/2  + iz * ny)* lsd2;
00492                                         part += x_data[ii] * double(y_data[ii]);
00493                                 }
00494                         }
00495                         result += part;
00496                 }
00497                 if( normalize ) {
00498                 //  it looks like it could work in 3D, but does not
00499                 double square_sum1 = 0., square_sum2 = 0.;
00500                 for ( int iz = 0; iz <= nz-1; ++iz) {
00501                         for ( int iy = 0; iy <= ny-1; ++iy) {
00502                                 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ++ix) {
00503                                         size_t ii = ix + (iy  + iz * ny)* lsd2;
00504                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00505                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00506                                 }
00507                         }
00508                         for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00509                                 size_t ii = (iy  + iz * ny)* lsd2;
00510                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00511                                 square_sum1 += x_data[ii+1] * double(x_data[ii+1]);
00512                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00513                                 square_sum2 += y_data[ii+1] * double(y_data[ii+1]);
00514                         }
00515                         if(nx%2 == 0) {
00516                                 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00517                                         size_t ii = lsd2 - 2 + (iy  + iz * ny)* lsd2;
00518                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00519                                         square_sum1 += x_data[ii+1] * double(x_data[ii+1]);
00520                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00521                                         square_sum2 += y_data[ii+1] * double(y_data[ii+1]);
00522                                 }
00523 
00524                         }
00525                         square_sum1 *= 2;
00526                         square_sum1 += x_data[0] * double(x_data[0]);
00527                         square_sum2 *= 2;
00528                         square_sum2 += y_data[0] * double(y_data[0]);
00529                         if(ny%2 == 0) {
00530                                 int ii = (ny/2  + iz * ny)* lsd2;
00531                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00532                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00533                         }
00534                         if(nx%2 == 0) {
00535                                 int ii = lsd2 - 2 + (0  + iz * ny)* lsd2;
00536                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00537                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00538                                 if(ny%2 == 0) {
00539                                         int ii = lsd2 - 2 +(ny/2  + iz * ny)* lsd2;
00540                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00541                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00542                                 }
00543                         }
00544                 }
00545                 result /= sqrt(square_sum1*square_sum2);
00546                 } 
00547                         else  result /= ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz);
00548                         // This concludes the 2D part.
00549                 } else {
00550                 // The 3D code
00551                 int ky, kz;
00552                 int ny2 = ny/2; int nz2 = nz/2;
00553                 for ( int iz = 0; iz <= nz-1; ++iz) {
00554                         if(iz>nz2) kz=iz-nz; else kz=iz;
00555                         for ( int iy = 0; iy <= ny-1; ++iy) {
00556                                 if(iy>ny2) ky=iy-ny; else ky=iy;
00557                                 for (int ix = 0; ix <= lsd2-1; ++ix) {
00558                                                 int p = 2;
00559                                                 if  (ix <= 1 || ix >= lsd2-2 && nx%2 == 0)   p = 1;
00560                                                 size_t ii = ix + (iy  + iz * ny)* (size_t)lsd2;
00561                                                 result += p * x_data[ii] * double(y_data[ii]);
00562                                 }
00563                         }
00564                 }
00565                 if( normalize ) {
00566                 double square_sum1 = 0., square_sum2 = 0.;
00567                 int ky, kz;
00568                 int ny2 = ny/2; int nz2 = nz/2;
00569                 for ( int iz = 0; iz <= nz-1; ++iz) {
00570                         if(iz>nz2) kz=iz-nz; else kz=iz;
00571                         for ( int iy = 0; iy <= ny-1; ++iy) {
00572                                 if(iy>ny2) ky=iy-ny; else ky=iy;
00573                                 for (int ix = 0; ix <= lsd2-1; ++ix) {
00574                                                 int p = 2;
00575                                                 if  (ix <= 1 || ix >= lsd2-2 && nx%2 == 0)   p = 1;
00576                                                 size_t ii = ix + (iy  + iz * ny)* (size_t)lsd2;
00577                                                 square_sum1 += p * x_data[ii] * double(x_data[ii]);
00578                                                 square_sum2 += p * y_data[ii] * double(y_data[ii]);
00579                                 }
00580                         }
00581                 }
00582                 result /= sqrt(square_sum1*square_sum2);
00583                 } 
00584                         else result /= ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz);
00585                 }
00586         } else {
00587                 // This part is for when two images are real, which is much easier because 2-D or 3-D
00588                 // is the same code.
00589                 size_t totsize = (size_t)image->get_xsize() * image->get_ysize() * image->get_zsize();
00590 
00591                 double square_sum1 = 0., square_sum2 = 0.;
00592 
00593                 if (params.has_key("mask")) {
00594                         EMData* mask;
00595                         mask = params["mask"];
00596                         const float *const dm = mask->get_const_data();
00597                         if (normalize) {
00598                                 for (size_t i = 0; i < totsize; i++) {
00599                                         if (dm[i] > 0.5) {
00600                                                 square_sum1 += x_data[i]*double(x_data[i]);
00601                                                 square_sum2 += y_data[i]*double(y_data[i]);
00602                                                 result += x_data[i]*double(y_data[i]);
00603                                         }
00604                                 }
00605                         } else {
00606                                 for (size_t i = 0; i < totsize; i++) {
00607                                         if (dm[i] > 0.5) {
00608                                                 result += x_data[i]*double(y_data[i]);
00609                                                 n++;
00610                                         }
00611                                 }
00612                         }
00613                 } else {
00614                         // no mask: accumulate over every pixel, taking products in double as the masked loops do
00615                         for (size_t i=0; i<totsize; i++) result+=x_data[i]*double(y_data[i]);
00616 
00617                         if (normalize) {
00618                                 square_sum1 = image->get_attr("square_sum");
00619                                 square_sum2 = with->get_attr("square_sum");
00620                         } else n = totsize;
00621                 }
00622                 if (normalize) result /= (sqrt(square_sum1*square_sum2)); else result /= n;
00623         }
00624 
00625 
00626         EXITFUNC;
00627         return (float) (negative*result);
00628 }
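The complex branch relies on Hermitian symmetry. The Fourier transform of a real image is stored only for kx = 0 to nx/2, as interleaved (real, imaginary) floats with lsd2 = nx + 2 - nx%2 values per row. Summing element-wise products of two such arrays yields Re(X * conj(Y)) for each coefficient. Every coefficient whose conjugate partner is not stored must therefore be counted twice. The 3D branch applies this rule through its weight p: 1 for the kx = 0 pair and, when nx is even, for the Nyquist pair, and 2 otherwise. The 2D branch reaches the same total by summing only half of those two columns, doubling, and then adding the self-conjugate points once.

The sketch below checks the 3D-branch weighting against Parseval's theorem in 2D. It rests on several assumptions: a naive O(N^2) DFT stands in for EMAN's FFT, and the forward transform is unscaled. The function names are hypothetical, and the sketch does not reproduce the (nx*ny*nz)^2 scaling that cmp() applies to the unnormalized complex result.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

// Naive unscaled 2D DFT of a real nx*ny image (row-major, x fastest), keeping
// only kx = 0..nx/2 as interleaved (re, im) pairs: lsd2 = nx + 2 - nx%2 values per row.
std::vector<double> half_spectrum(const std::vector<double>& img, int nx, int ny) {
    const int lsd2 = nx + 2 - nx % 2;
    const double pi = std::acos(-1.0);
    std::vector<double> out(std::size_t(lsd2) * ny, 0.0);
    for (int ky = 0; ky < ny; ++ky) {
        for (int kx = 0; kx <= nx / 2; ++kx) {
            std::complex<double> s(0.0, 0.0);
            for (int y = 0; y < ny; ++y)
                for (int x = 0; x < nx; ++x) {
                    const double phase = -2.0 * pi * (double(kx) * x / nx + double(ky) * y / ny);
                    s += img[std::size_t(y) * nx + x] * std::polar(1.0, phase);
                }
            const std::size_t ii = std::size_t(ky) * lsd2 + 2 * std::size_t(kx);
            out[ii] = s.real();
            out[ii + 1] = s.imag();
        }
    }
    return out;
}

// Full-spectrum sum of Re(X * conj(Y)) computed from half spectra,
// using the same weight p as the 3D branch of DotCmp::cmp.
double half_spectrum_dot(const std::vector<double>& X, const std::vector<double>& Y,
                         int nx, int ny) {
    const int lsd2 = nx + 2 - nx % 2;
    double sum = 0.0;
    for (int iy = 0; iy < ny; ++iy)
        for (int ix = 0; ix < lsd2; ++ix) {
            // kx = 0 (and kx = nx/2 for even nx): conjugate partners are stored, count once.
            const int p = (ix <= 1 || (nx % 2 == 0 && ix >= lsd2 - 2)) ? 1 : 2;
            const std::size_t ii = std::size_t(iy) * lsd2 + ix;
            sum += p * X[ii] * Y[ii];  // re*re and im*im terms add up to Re(X conj(Y))
        }
    return sum;
}
```

Dividing half_spectrum_dot by nx*ny recovers the real-space dot product for both even and odd nx, which is what the weighting is meant to guarantee.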

string EMAN::DotCmp::get_desc (  )  const [inline, virtual]

Implements EMAN::Cmp.

Definition at line 282 of file cmp.h.

00283                 {
00284                         return "Dot product (default -1 * dot product)";
00285                 }

string EMAN::DotCmp::get_name (  )  const [inline, virtual]

Get the Cmp's name.

Each Cmp is identified by a unique name.

Returns:
The Cmp's name.

Implements EMAN::Cmp.

Definition at line 277 of file cmp.h.

References NAME.

00278                 {
00279                         return NAME;
00280                 }

TypeDict EMAN::DotCmp::get_param_types (  )  const [inline, virtual]

Get Cmp parameter information in a dictionary.

Each parameter has one record in the dictionary. Each record contains its name, data-type, and description.

Returns:
A dictionary containing the parameter info.

Implements EMAN::Cmp.

Definition at line 292 of file cmp.h.

References EMAN::EMObject::EMDATA, EMAN::EMObject::INT, and EMAN::TypeDict::put().

00293                 {
00294                         TypeDict d;
00295                         d.put("negative", EMObject::INT, "If set, returns -1 * dot product. Set by default so smaller is better");
00296                         d.put("normalize", EMObject::INT, "If set, returns normalized dot product (cosine of the angle) -1.0 - 1.0.");
00297                         d.put("mask", EMObject::EMDATA, "image mask");
00298                         return d;
00299                 }

static Cmp* EMAN::DotCmp::NEW (  )  [inline, static]

Definition at line 287 of file cmp.h.

00288                 {
00289                         return new DotCmp();
00290                 }
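DotCmp exposes a static NAME ("dot") and a static NEW() that returns a fresh instance. Together these let a factory create comparators from a name string without knowing their concrete types. The sketch below shows that pattern in isolation. Comparator, DotLike and Registry are hypothetical stand-ins and are not EMAN2's actual Cmp or factory classes.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical minimal comparator interface (not EMAN2's Cmp).
struct Comparator {
    virtual ~Comparator() = default;
    virtual std::string get_name() const = 0;
};

// Hypothetical comparator following DotCmp's NAME / NEW() convention.
struct DotLike : Comparator {
    static const std::string NAME;
    static Comparator* NEW() { return new DotLike(); }
    std::string get_name() const override { return NAME; }
};
const std::string DotLike::NAME = "dot";

// Hypothetical registry keyed by each class's NAME, storing its static NEW().
class Registry {
public:
    using Creator = Comparator* (*)();

    template <class T>
    void add() { creators_[T::NAME] = &T::NEW; }

    std::unique_ptr<Comparator> create(const std::string& name) const {
        auto it = creators_.find(name);
        if (it == creators_.end()) throw std::out_of_range("unknown comparator: " + name);
        return std::unique_ptr<Comparator>(it->second());  // take ownership of the new object
    }

private:
    std::map<std::string, Creator> creators_;
};
```

A caller registers each class once with add<T>() and afterwards needs only the string, for example create("dot").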


Member Data Documentation

const string DotCmp::NAME = "dot" [static]

Definition at line 301 of file cmp.h.

Referenced by get_name().


The documentation for this class was generated from the following files:
cmp.h
cmp.cpp
Generated on Tue Jun 11 12:42:57 2013 for EMAN2 by doxygen 1.4.7