// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "TwoClassFormat.h"

namespace Torch {

// Builds a two-class format by scanning every example of `data` and
// collecting the distinct values of its first target.
//
// - At most two distinct values are accepted; a third triggers error().
// - With two classes, the labels are stored in increasing order
//   (tabclasses[0] < tabclasses[1]).
// - With a single class, both slots hold that value.
// - With no examples, both slots are set to 0.
// A warning (not an error) is issued when data->n_targets != 1; only the
// first target value of each example is inspected.
TwoClassFormat::TwoClassFormat(DataSet *data)
{
  if(data->n_targets != 1)
    warning("TwoClassFormat: the data has %d outputs", data->n_targets);

  // Number of distinct labels collected so far into tabclasses.
  int n_set = 0;
  for(int i = 0; i < data->n_examples; i++)
  {
    data->setExample(i);
    real target = ((real *)data->targets)[0];

    // Exact comparison on purpose: class labels are discrete values.
    bool flag = false;
    for(int k = 0; k < n_set; k++)
    {
      if(target == tabclasses[k])
      {
        flag = true;
        break;
      }
    }

    if(!flag)
    {
      // Only two label slots exist. The write is guarded by `else` so the
      // storage cannot be overrun even if error() were to return.
      if(n_set >= 2)
        error("TwoClassFormat: you have more than two classes");
      else
        tabclasses[n_set++] = target;
    }
  }

  switch(n_set)
  {
    case 0:
      warning("TwoClassFormat: you have no examples");
      tabclasses[0] = 0;
      tabclasses[1] = 0;
      break;
    case 1:
      warning("TwoClassFormat: you have only one class [%g]", tabclasses[0]);
      tabclasses[1] = tabclasses[0];
      break;
    case 2:
      // Normalize the order so that class index 0 is the smaller label.
      if(tabclasses[0] > tabclasses[1])
      {
        real z = tabclasses[1];
        tabclasses[1] = tabclasses[0];
        tabclasses[0] = z;
      }
      message("TwoClassFormat: two classes detected [%g and %g]", tabclasses[0], tabclasses[1]);
      break;
  }

  // Always expose exactly two class labels, even when fewer were found.
  // class_labels[i] points into tabclasses, so only the pointer array
  // itself is allocated (and later released in the destructor).
  n_classes = 2;
  class_labels = (real **)xalloc(sizeof(real *)*n_classes);
  for(int i = 0; i < n_classes; i++)
    class_labels[i] = tabclasses+i;
}

// Builds a two-class format from explicitly supplied labels.
// class_1 becomes class index 0 and class_2 class index 1. Unlike the
// DataSet-based constructor, the pair is stored exactly as given (not sorted).
TwoClassFormat::TwoClassFormat(real class_1, real class_2)
{
  n_classes = 2;
  tabclasses[0] = class_1;
  tabclasses[1] = class_2;

  // class_labels holds pointers into tabclasses; only the pointer array is
  // obtained from xalloc (it is released in the destructor).
  class_labels = (real **)xalloc(sizeof(real *)*n_classes);
  real *label = tabclasses;
  for(int k = 0; k < n_classes; k++)
    class_labels[k] = label++;
}

// Both classes are represented by a single scalar value, so this format
// always reports an output size of one.
int TwoClassFormat::getOutputSize()
{
  const int single_output = 1;
  return single_output;
}

// Collapses a pair of per-class values into the single scalar output.
// The result is (first - second), negated when tabclasses[1] > tabclasses[0].
// The second value is read from the same List node when that node holds two
// reals (n == 2); otherwise it is read from the next node in the list.
// NOTE(review): toOneHot orders its pair by larger/smaller label value,
// while this function keys the sign flip on tabclasses order. Verify that
// the two are meant to be inverses of each other.
void TwoClassFormat::fromOneHot(List *outputs, List *one_hot_outputs)
{
  real *first = (real*)one_hot_outputs->ptr;
  real *second = (one_hot_outputs->n == 2) ? first + 1
                                           : (real*)one_hot_outputs->next->ptr;
  real *out = (real*)outputs->ptr;

  *out = *first - *second;
  if(tabclasses[1] > tabclasses[0])
    *out = - *out;
}

// Expands the scalar output into a pair of values:
// - first:  absolute distance from the output to the larger class label;
// - second: absolute distance from the output to the smaller class label.
// When the two labels are equal, both distances use tabclasses[0].
// The second value is written into the same List node when that node holds
// two reals (n == 2); otherwise it goes to the next node in the list.
void TwoClassFormat::toOneHot(List *outputs, List *one_hot_outputs)
{
  real *out = (real*)outputs->ptr;

  // Index in tabclasses of the larger and the smaller label.
  int hi = (tabclasses[1] > tabclasses[0]) ? 1 : 0;
  int lo = (tabclasses[0] > tabclasses[1]) ? 1 : 0;

  real *first = (real*)one_hot_outputs->ptr;
  real *second = (one_hot_outputs->n == 2) ? first + 1
                                           : (real*)one_hot_outputs->next->ptr;

  *first = fabs(*out - tabclasses[hi]);
  *second = fabs(*out - tabclasses[lo]);
}

int TwoClassFormat::getTargetClass(void *target)
{
  real out = *((real *)target);
  
  return(fabs(out - tabclasses[0]) > fabs(out - tabclasses[1]) ? 1 : 0);
}

// Maps the first value of `outputs` to the index (0 or 1) of the nearest
// class label. Ties resolve to class 0.
int TwoClassFormat::getOutputClass(List *outputs)
{
  real value = *((real *)outputs->ptr);
  bool nearer_second = fabs(value - tabclasses[0]) > fabs(value - tabclasses[1]);

  if(nearer_second)
    return 1;
  return 0;
}

// Releases the pointer array created by the constructors. The labels it
// points to live in tabclasses (part of this object), so they are not freed.
TwoClassFormat::~TwoClassFormat()
{
  real **labels = class_labels;
  free(labels);
}

}

