
EMAN::DotCmp Class Reference
[a function or class that is CUDA enabled]

Uses the dot product of two same-size images to do the comparison.

#include <cmp.h>


Public Member Functions

float cmp (EMData *image, EMData *with) const
 Compares 'image' with another image, 'with', passed in as an argument.
string get_name () const
 Get the Cmp's name.
string get_desc () const
 Get a short description of the Cmp.
TypeDict get_param_types () const
 Get Cmp parameter information in a dictionary.

Static Public Member Functions

Cmp * NEW ()

Static Public Attributes

const string NAME = "dot"

Detailed Description

Uses the dot product of two same-size images to do the comparison.

Mask option added by PAP, 04/23/06. For complex images, the comparison does not check whether the data are stored as real/imaginary (r/i) or amplitude/phase (a/p).

Author:
Steve Ludtke (sludtke@bcm.tmc.edu)
Date:
2005-07-13
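Parameters:
negative If set, returns -1 * dot product so that smaller is better; default 1 (true)
normalize If set, returns the normalized dot product (the cosine of the angle between the images), in the range -1.0 to 1.0; default 0 (false)
mask Optional mask image; on the CPU path for real-space images, only voxels whose mask value exceeds 0.5 contribute. The mask is ignored for complex images.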

Definition at line 272 of file cmp.h.
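For real-space images, the CPU path of cmp() reduces to a few lines. The sketch below restates it as a free function over std::vector so that the effect of negative, normalize and mask can be checked in isolation. dot_cmp_sketch is a hypothetical name, not part of EMAN2, and the CUDA and complex-image branches are not modelled. Without a mask, the listing reads the precomputed "square_sum" attribute; the sketch recomputes the sums instead, which should give the same value. Unlike the listing, it returns 0 rather than dividing by zero when the mask selects no voxels.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-alone version of the real-space DotCmp logic.
// An empty mask means "no mask"; otherwise only voxels with mask > 0.5 count.
float dot_cmp_sketch(const std::vector<float>& x, const std::vector<float>& y,
                     const std::vector<float>& mask, bool normalize, bool negative)
{
    double result = 0.0, ss1 = 0.0, ss2 = 0.0;
    std::size_t n = 0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        if (!mask.empty() && mask[i] <= 0.5f) continue;  // voxel outside the mask
        result += x[i] * double(y[i]);                   // accumulate in double
        ss1 += x[i] * double(x[i]);
        ss2 += y[i] * double(y[i]);
        ++n;
    }
    if (normalize) result /= std::sqrt(ss1 * ss2);  // cosine of the angle, in [-1, 1]
    else if (n > 0) result /= n;                     // mean product per included voxel
    return static_cast<float>(negative ? -result : result);
}
```

With the defaults (negative set, normalize unset), two identical images give the most negative score, which is why smaller is better for this comparator.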


Member Function Documentation

float DotCmp::cmp(EMData * image, EMData * with) const [virtual]

Compares 'image' with another image, 'with', passed in as an argument.

An optional transformation may be used to transform the two images.

Parameters:
image The first image to be compared.
with The second image to be compared.
Returns:
The comparison result. Smaller is better by default.

Implements EMAN::Cmp.

Definition at line 407 of file cmp.cpp.

References dm, dot_cmp_cuda(), EMAN::EMData::get_attr(), EMAN::EMData::get_const_data(), EMAN::EMData::get_xsize(), EMAN::EMData::get_ysize(), EMAN::EMData::get_zsize(), EMAN::Dict::has_key(), EMAN::EMData::is_complex(), EMAN::EMData::is_fftodd(), nx, ny, EMAN::Dict::set_default(), sqrt(), and EMAN::Cmp::validate_input_args().

Referenced by EMAN::EMData::dot().

00408 {
00409         ENTERFUNC;
00410         
00411         validate_input_args(image, with);
00412 
00413         int normalize = params.set_default("normalize", 0);
00414         float negative = (float)params.set_default("negative", 1);
00415         if (negative) negative=-1.0; else negative=1.0;
00416 #ifdef EMAN2_USING_CUDA // So far this only works for real images. The CUDA path comes first to avoid the non-CUDA overhead (calls to get_data are expensive).
00417         if(image->is_complex() && with->is_complex()) { // complex images: handled by the CPU code below
00418         } else {
00419                 if (image->getcudarwdata() && with->getcudarwdata()) {
00420                         //cout << "CUDA dot cmp" << endl;
00421                         float* maskdata = 0;
00422                         bool has_mask = false;
00423                         EMData* mask = 0;
00424                         if (params.has_key("mask")) {
00425                                 mask = params["mask"];
00426                                 if(mask!=0) {has_mask=true;}
00427                         }
00428                         if(has_mask) {
00429                                 if(!mask->getcudarwdata()) mask->copy_to_cuda(); // upload the mask only if it is not on the GPU yet
00430                                 maskdata = mask->getcudarwdata();
00431                         }
00432 
00433                         float result = dot_cmp_cuda(image->getcudarwdata(), with->getcudarwdata(), maskdata, image->get_xsize(), image->get_ysize(), image->get_zsize());
00434                         result *= negative;
00435 
00436                         return result;
00437                         
00438                 }
00439         }
00440 #endif
00441         const float *const x_data = image->get_const_data();
00442         const float *const y_data = with->get_const_data();
00443 
00444         double result = 0.;
00445         long n = 0;
00446         if(image->is_complex() && with->is_complex()) {
00447         // Implemented by PAP  01/09/06 - please do not change.  If in doubts, write/call me.
00448                 int nx  = with->get_xsize();
00449                 int ny  = with->get_ysize();
00450                 int nz  = with->get_zsize();
00451                 nx = (nx - 2 + with->is_fftodd()); // nx is the real-space size of the input image
00452                 int lsd2 = (nx + 2 - nx%2) ; // Extended x-dimension of the complex image
00453 
00454                 int ixb = 2*((nx+1)%2);
00455                 int iyb = ny%2;
00456                 //
00457                 if(nz == 1) {
00458                 //  it looks like it could work in 3D, but does not
00459                 for ( int iz = 0; iz <= nz-1; ++iz) {
00460                         double part = 0.;
00461                         for ( int iy = 0; iy <= ny-1; ++iy) {
00462                                 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ++ix) {
00463                                         size_t ii = ix + (iy  + iz * ny)* lsd2;
00464                                         part += x_data[ii] * double(y_data[ii]);
00465                                 }
00466                         }
00467                         for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00468                                 size_t ii = (iy  + iz * ny)* lsd2;
00469                                 part += x_data[ii] * double(y_data[ii]);
00470                                 part += x_data[ii+1] * double(y_data[ii+1]);
00471                         }
00472                         if(nx%2 == 0) {
00473                                 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00474                                         size_t ii = lsd2 - 2 + (iy  + iz * ny)* lsd2;
00475                                         part += x_data[ii] * double(y_data[ii]);
00476                                         part += x_data[ii+1] * double(y_data[ii+1]);
00477                                 }
00478 
00479                         }
00480                         part *= 2;
00481                         part += x_data[0] * double(y_data[0]);
00482                         if(ny%2 == 0) {
00483                                 size_t ii = (ny/2  + iz * ny)* lsd2;
00484                                 part += x_data[ii] * double(y_data[ii]);
00485                         }
00486                         if(nx%2 == 0) {
00487                                 size_t ii = lsd2 - 2 + (0  + iz * ny)* lsd2;
00488                                 part += x_data[ii] * double(y_data[ii]);
00489                                 if(ny%2 == 0) {
00490                                         size_t ii = lsd2 - 2 +(ny/2  + iz * ny)* lsd2;
00491                                         part += x_data[ii] * double(y_data[ii]);
00492                                 }
00493                         }
00494                         result += part;
00495                 }
00496                 if( normalize ) {
00497                 //  it looks like it could work in 3D, but does not
00498                 double square_sum1 = 0., square_sum2 = 0.;
00499                 for ( int iz = 0; iz <= nz-1; ++iz) {
00500                         for ( int iy = 0; iy <= ny-1; ++iy) {
00501                                 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ++ix) {
00502                                         size_t ii = ix + (iy  + iz * ny)* lsd2;
00503                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00504                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00505                                 }
00506                         }
00507                         for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00508                                 size_t ii = (iy  + iz * ny)* lsd2;
00509                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00510                                 square_sum1 += x_data[ii+1] * double(x_data[ii+1]);
00511                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00512                                 square_sum2 += y_data[ii+1] * double(y_data[ii+1]);
00513                         }
00514                         if(nx%2 == 0) {
00515                                 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00516                                         size_t ii = lsd2 - 2 + (iy  + iz * ny)* lsd2;
00517                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00518                                         square_sum1 += x_data[ii+1] * double(x_data[ii+1]);
00519                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00520                                         square_sum2 += y_data[ii+1] * double(y_data[ii+1]);
00521                                 }
00522 
00523                         }
00524                         square_sum1 *= 2;
00525                         square_sum1 += x_data[0] * double(x_data[0]);
00526                         square_sum2 *= 2;
00527                         square_sum2 += y_data[0] * double(y_data[0]);
00528                         if(ny%2 == 0) {
00529                                 size_t ii = (ny/2  + iz * ny)* lsd2;
00530                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00531                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00532                         }
00533                         if(nx%2 == 0) {
00534                                 size_t ii = lsd2 - 2 + (0  + iz * ny)* lsd2;
00535                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00536                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00537                                 if(ny%2 == 0) {
00538                                         size_t ii = lsd2 - 2 +(ny/2  + iz * ny)* lsd2;
00539                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00540                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00541                                 }
00542                         }
00543                 }
00544                 result /= sqrt(square_sum1*square_sum2);
00545                 } else  result /= ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz);
00546 
00547                 } else { //This 3D code is incorrect, but it is the best I can do now 01/09/06 PAP
00548                 int ky, kz;
00549                 int ny2 = ny/2; int nz2 = nz/2;
00550                 for ( int iz = 0; iz <= nz-1; ++iz) {
00551                         if(iz>nz2) kz=iz-nz; else kz=iz;
00552                         for ( int iy = 0; iy <= ny-1; ++iy) {
00553                                 if(iy>ny2) ky=iy-ny; else ky=iy;
00554                                 for ( int ix = 0; ix <= lsd2-1; ++ix) {
00555                                         // Skip Friedel related values
00556                                         if(ix>0 || (kz>=0 && (ky>=0 || kz!=0))) {
00557                                                 size_t ii = ix + (iy  + iz * ny)* (size_t)lsd2;
00558                                                 result += x_data[ii] * double(y_data[ii]);
00559                                         }
00560                                 }
00561                         }
00562                 }
00563                 if( normalize ) {
00564                 //  still incorrect
00565                 double square_sum1 = 0., square_sum2 = 0.;
00566                 int ky, kz;
00567                 int ny2 = ny/2; int nz2 = nz/2;
00568                 for ( int iz = 0; iz <= nz-1; ++iz) {
00569                         if(iz>nz2) kz=iz-nz; else kz=iz;
00570                         for ( int iy = 0; iy <= ny-1; ++iy) {
00571                                 if(iy>ny2) ky=iy-ny; else ky=iy;
00572                                 for ( int ix = 0; ix <= lsd2-1; ++ix) {
00573                                         // Skip Friedel related values
00574                                         if(ix>0 || (kz>=0 && (ky>=0 || kz!=0))) {
00575                                                 size_t ii = ix + (iy  + iz * ny)* (size_t)lsd2;
00576                                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00577                                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00578                                         }
00579                                 }
00580                         }
00581                 }
00582                 result /= sqrt(square_sum1*square_sum2);
00583                 } else result /= ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz/2);
00584                 }
00585         } else {
00586                 
00587                 size_t totsize = (size_t)image->get_xsize() * image->get_ysize() * image->get_zsize();
00588 
00589                 double square_sum1 = 0., square_sum2 = 0.;
00590 
00591                 if (params.has_key("mask")) {
00592                         EMData* mask;
00593                         mask = params["mask"];
00594                         const float *const dm = mask->get_const_data();
00595                         if (normalize) {
00596                                 for (size_t i = 0; i < totsize; i++) {
00597                                         if (dm[i] > 0.5) {
00598                                                 square_sum1 += x_data[i]*double(x_data[i]);
00599                                                 square_sum2 += y_data[i]*double(y_data[i]);
00600                                                 result += x_data[i]*double(y_data[i]);
00601                                         }
00602                                 }
00603                         } else {
00604                                 for (size_t i = 0; i < totsize; i++) {
00605                                         if (dm[i] > 0.5) {
00606                                                 result += x_data[i]*double(y_data[i]);
00607                                                 n++;
00608                                         }
00609                                 }
00610                         }
00611                 } else {
00612                         // Unmasked case: accumulate the product over every voxel.
00613                         for (size_t i=0; i<totsize; i++) result+=x_data[i]*y_data[i];
00614 
00615                         if (normalize) {
00616                                 square_sum1 = image->get_attr("square_sum");
00617                                 square_sum2 = with->get_attr("square_sum");
00618                         } else n = totsize;
00619                 }
00620                 if (normalize) result /= (sqrt(square_sum1*square_sum2)); else result /= n;
00621         }
00622 
00623 
00624         EXITFUNC;
00625         return (float) (negative*result);
00626 }
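The complex branch evaluates the dot product in Fourier space. Because the images are real, their transforms have Friedel (Hermitian) symmetry, so only half of the spectrum is stored: interior coefficients are summed once and the total is doubled, while self-conjugate coefficients (the origin, plus the Nyquist terms when a dimension is even) are added once. The 1D sketch below shows the same bookkeeping with a naive DFT and checks it against Parseval's theorem. It does not reproduce EMAN2's interleaved real/imaginary array layout, and half_dft and dot_from_half_spectra are hypothetical names. It uses the textbook 1/N factor for an unnormalized DFT; the listing applies its own scale factor instead.

```cpp
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using cd = std::complex<double>;

// Naive DFT of a real signal, keeping only bins 0..N/2 (the non-redundant half).
std::vector<cd> half_dft(const std::vector<double>& x)
{
    const std::size_t N = x.size();
    const double two_pi = 2.0 * std::acos(-1.0);
    std::vector<cd> X(N / 2 + 1);
    for (std::size_t k = 0; k <= N / 2; ++k)
        for (std::size_t j = 0; j < N; ++j)
            X[k] += x[j] * std::polar(1.0, -two_pi * double(k * j) / double(N));
    return X;
}

// Real-space dot product recovered from half spectra via Parseval's theorem.
// Interior bins stand for themselves and their Friedel mates, so they count twice;
// bin 0 and (for even N) the Nyquist bin are self-conjugate and count once.
double dot_from_half_spectra(const std::vector<cd>& X, const std::vector<cd>& Y, std::size_t N)
{
    const std::size_t last = (N % 2 == 0) ? N / 2 - 1 : N / 2;  // last interior bin
    double interior = 0.0;
    for (std::size_t k = 1; k <= last; ++k)
        interior += (X[k] * std::conj(Y[k])).real();  // = Xr*Yr + Xi*Yi
    double sum = 2.0 * interior + (X[0] * std::conj(Y[0])).real();
    if (N % 2 == 0) sum += (X[N / 2] * std::conj(Y[N / 2])).real();
    return sum / double(N);
}
```

The term Xr*Yr + Xi*Yi is what the listing accumulates when it multiplies x_data[ii] by y_data[ii] and x_data[ii+1] by y_data[ii+1] for each stored complex coefficient.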

string EMAN::DotCmp::get_desc() const [inline, virtual]

Implements EMAN::Cmp.

Definition at line 282 of file cmp.h.

00283                 {
00284                         return "Dot product (default -1 * dot product)";
00285                 }

string EMAN::DotCmp::get_name() const [inline, virtual]

Get the Cmp's name.

Each Cmp is identified by a unique name.

Returns:
The Cmp's name.

Implements EMAN::Cmp.

Definition at line 277 of file cmp.h.

00278                 {
00279                         return NAME;
00280                 }

TypeDict EMAN::DotCmp::get_param_types() const [inline, virtual]

Get Cmp parameter information in a dictionary.

Each parameter has one record in the dictionary. Each record contains its name, data-type, and description.

Returns:
A dictionary containing the parameter info.

Implements EMAN::Cmp.

Definition at line 292 of file cmp.h.

References EMAN::TypeDict::put().

00293                 {
00294                         TypeDict d;
00295                         d.put("negative", EMObject::INT, "If set, returns -1 * dot product. Set by default so smaller is better");
00296                         d.put("normalize", EMObject::INT, "If set, returns normalized dot product (cosine of the angle) -1.0 - 1.0.");
00297                         d.put("mask", EMObject::EMDATA, "image mask");
00298                         return d;
00299                 }

Cmp* EMAN::DotCmp::NEW() [inline, static]

Definition at line 287 of file cmp.h.

00288                 {
00289                         return new DotCmp();
00290                 }
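The static NAME and NEW() pair lets a name-keyed factory create a comparator from a string such as "dot". The sketch below shows that pattern in isolation. It is not EMAN2's own factory class; CmpBase, DotCmpSketch and CmpRegistry are hypothetical names that stand in for EMAN::Cmp, EMAN::DotCmp and the registry.

```cpp
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical base class standing in for EMAN::Cmp (only what the sketch needs).
struct CmpBase {
    virtual ~CmpBase() = default;
    virtual std::string get_name() const = 0;
};

// Hypothetical comparator exposing the same static NAME / NEW() hooks as DotCmp.
struct DotCmpSketch : CmpBase {
    static constexpr const char* NAME = "dot";
    std::string get_name() const override { return NAME; }
    static CmpBase* NEW() { return new DotCmpSketch(); }
};

// Registry mapping each comparator's NAME to its static NEW() creator.
class CmpRegistry {
public:
    template <class T> void add() { creators_[T::NAME] = &T::NEW; }

    // Creates a new comparator by name; throws for unknown names.
    std::unique_ptr<CmpBase> get(const std::string& name) const {
        auto it = creators_.find(name);
        if (it == creators_.end()) throw std::out_of_range("unknown cmp: " + name);
        return std::unique_ptr<CmpBase>(it->second());
    }

private:
    std::map<std::string, CmpBase* (*)()> creators_;
};
```

After registry.add<DotCmpSketch>(), a call such as registry.get("dot") returns a fresh comparator, so callers never need to name the concrete class.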


Member Data Documentation

const string DotCmp::NAME = "dot" [static]
 

Definition at line 52 of file cmp.cpp.


The documentation for this class was generated from the following files: cmp.h, cmp.cpp
Generated on Thu Nov 17 12:45:26 2011 for EMAN2 by doxygen 1.3.9.1