// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "HtkFileDataSet.h"

namespace Torch {

// Single-file constructor: loads `file` through IOHtk(file, max_load) and
// exposes every run of window_size_ consecutive frames as one example.
// Each example is window_size_ * (columns of the file) reals wide, and the
// targets are the inputs themselves.
HtkFileDataSet::HtkFileDataSet(char *file, int window_size_, int max_load)
{
  n_files = 1;
  window_size = window_size_;

  htk = (IOHtk**)xalloc(sizeof(IOHtk*) * n_files);
  htk[0] = new IOHtk(file, max_load);

  n_targets = n_inputs = window_size * htk[0]->n_cols;

  // prepareData() grows all_inputs with xrealloc, so it must start as NULL.
  all_inputs = NULL;
  all_targets = NULL;
  prepareData();
}

// Multi-file constructor: loads each of the n_files_ paths in `files`
// through IOHtk(path, max_load) and exposes every run of window_size_
// consecutive frames (within one file) as one example.
// The example width is taken from the first file, so all files are assumed
// to have the same number of columns; targets are the inputs themselves.
// NOTE(review): htk[0] is read unconditionally, so n_files_ must be >= 1.
HtkFileDataSet::HtkFileDataSet(char **files, int n_files_, int window_size_, int max_load)
{
  n_files = n_files_;
  window_size = window_size_;

  htk = (IOHtk**)xalloc(sizeof(IOHtk*) * n_files);
  for (int f = 0; f < n_files; ++f)
    htk[f] = new IOHtk(files[f], max_load);

  const int frame_width = htk[0]->n_cols;
  n_inputs = window_size * frame_width;
  n_targets = n_inputs;

  // prepareData() grows all_inputs with xrealloc, so it must start as NULL.
  all_inputs = NULL;
  all_targets = NULL;
  prepareData();
}

void HtkFileDataSet::write(char* dir_to_save){
	for(int i = 0;i < n_real_examples;i++)
		htk[i]->write(dir_to_save);
}

// (Re)build all_inputs, with one pointer per example. Example (i, j) is the
// window of window_size consecutive frames of file i starting at frame j, and
// the pointer addresses frame j inside htk[i]->data (frames are stored back
// to back, n_inputs/window_size reals each).
// input_to_keep, when non-NULL, holds one flag per candidate window in file
// order; only windows whose flag is true are kept.
// On return n_examples == n_real_examples == number of kept windows, and
// all_targets aliases all_inputs.
void HtkFileDataSet::prepareData(bool* input_to_keep)
{
  // Width of one stored frame. n_inputs is always window_size * frame width,
  // so this division is exact.
  const int frame_size = n_inputs/window_size;

  // First count the candidate windows. A file shorter than the window
  // contributes none. Clamping at 0 matters: a negative contribution would
  // make the total smaller than what the fill loop below writes for the
  // other files, and all_inputs would overflow.
  n_real_examples = 0;
  for (int i=0;i<n_files;i++) {
    const int n_windows = htk[i]->n_lines - window_size + 1;
    if (n_windows > 0)
      n_real_examples += n_windows;
  }

  // Then (re)build the pointer table, skipping windows not flagged for keeping.
  all_inputs = (real**)xrealloc(all_inputs,sizeof(real*)*n_real_examples);
  int k=0;
  bool* keep_flag = input_to_keep;
  for (int i=0;i<n_files;i++) {
    const int n_windows = htk[i]->n_lines - window_size + 1;
    for (int j=0;j<n_windows;j++)
      if(!input_to_keep || *keep_flag++)
        all_inputs[k++] = &htk[i]->data[j*frame_size];
  }
  n_examples = n_real_examples = k;

  // Targets are the inputs themselves: share the same pointer table.
  all_targets = all_inputs;
}

// Finish setting up the dataset: run DataSet::init(), apply this class's
// per-frame normalize(), then call addToList(&inputs, n_inputs, NULL)
// (presumably registering the `inputs` sequence of width n_inputs; confirm
// in DataSet).
// NOTE(review): calls DataSet::init() directly rather than StdDataSet::init();
// confirm that skipping the intermediate base is intended.
void HtkFileDataSet::init()
{
  DataSet::init();
  normalize();

  addToList(&inputs, n_inputs, NULL);
}

void HtkFileDataSet::normalize()
{
  // we first need to fool the StdDataSet::normalize by modifying
  // the number of window_size
  if (window_size == 1) {
    StdDataSet::normalize();
  } else {
    int old_n_real_examples = n_real_examples;
    int old_n_inputs = n_inputs;
    n_inputs = n_inputs/window_size;
    n_targets = n_inputs;
    n_real_examples = 0;
    for (int i=0;i<n_files;i++) {
      n_real_examples += htk[i]->n_lines;
    }
    real** old_all_inputs = all_inputs;
    all_inputs = (real**)xalloc(sizeof(real*)*n_real_examples);
    int k=0;
    for (int i=0;i<n_files;i++) {
      for (int j=0;j<htk[i]->n_lines;j++,k++)
        all_inputs[k] = &htk[i]->data[j*n_inputs];
    }
    all_targets = all_inputs;
    StdDataSet::normalize();
    free(all_inputs);
    all_inputs = old_all_inputs;
    all_targets = all_inputs;
    n_real_examples = old_n_real_examples;
    n_inputs = old_n_inputs;
    n_targets = n_inputs;
  }
}

// Normalize the stored inputs with data_norm's mean_i / stdv_i.
// With window_size > 1 the statistics are per frame (n_inputs/window_size
// values) and consecutive windows share frames. Each stored frame is
// therefore normalized exactly once, through a temporary
// one-pointer-per-frame table, rather than through the overlapping window
// pointers. Nothing happens when data_norm has no input normalization.
void HtkFileDataSet::normalizeUsingDataSet(StdDataSet *data_norm)
{
  if (window_size == 1) {
    StdDataSet::normalizeUsingDataSet(data_norm);
    return;
  }
  if (!data_norm->norm_inputs)
    return;

  const int frame_size = n_inputs/window_size;
  int n_frames = 0;
  for (int f = 0; f < n_files; f++)
    n_frames += htk[f]->n_lines;

  real** frames = (real**)xalloc(sizeof(real*)*n_frames);
  int slot = 0;
  for (int f = 0; f < n_files; f++)
    for (int line = 0; line < htk[f]->n_lines; line++)
      frames[slot++] = &htk[f]->data[line*frame_size];

  for (int i = 0; i < n_frames; i++) {
    real* x = frames[i];
    for (int d = 0; d < frame_size; d++)
      x[d] = (x[d]-data_norm->mean_i[d])/data_norm->stdv_i[d];
  }
  free(frames);

  // Targets share the input pointer table and width.
  all_targets = all_inputs;
  n_targets = n_inputs;
}

// Drop examples (windows) that score worse under likely_distr than under
// unlikely_distr.
// For every current example the frame log-probability under both
// distributions is computed. Example j is kept iff the sum of likely_distr
// log-probabilities over examples j-range..j+range is >= the same sum under
// unlikely_distr. The first and last `range` examples are never kept:
// inputs_to_keep is zero-initialised by xcalloc and the loop skips them.
// If `mask` is non-NULL, feature column l of every file is also discarded
// when mask[l] is true. Each file's frames are compacted in place at the
// front of h->data, and n_inputs is recomputed from the kept column count of
// htk[0] (h->n_cols / h->n_lines themselves are left unchanged).
// Finally all_inputs is rebuilt via prepareData(inputs_to_keep).
// Returns the number of examples removed.
// NOTE(review): all_inputs spans all files back to back, so the smoothing
// window can straddle a file boundary.
// NOTE(review): looks intended to be called once. inputs_to_keep has the
// current n_real_examples entries, but prepareData recounts windows from the
// file lengths, and the column compaction assumes a stride of h->n_cols.
// Both disagree after a previous call has removed frames or columns.
int HtkFileDataSet::removeUnlikelyFrames(Distribution* likely_distr, Distribution* unlikely_distr,bool* mask){
	// prepare both distributions before evaluating them (presumably required
	// by frameLogProbability; confirm in Distribution)
	likely_distr->eMIterInitialize();
	unlikely_distr->eMIterInitialize();
	unlikely_distr->eMSequenceInitialize(NULL);
	likely_distr->eMSequenceInitialize(NULL);
	// log-probability of each current example under both distributions
	real* unlikely_lp = (real*) xalloc(sizeof(real)*n_real_examples);
	real* likely_lp = (real*) xalloc(sizeof(real)*n_real_examples);
	for(int i=0;i<n_real_examples;i++){
		unlikely_lp[i]=unlikely_distr->frameLogProbability(all_inputs[i],NULL,0);
		likely_lp[i]=likely_distr->frameLogProbability(all_inputs[i],NULL,0);
	}
	// `range` = number of neighbouring examples on each side of the current
	// one whose log-probabilities are added in to make the keep/drop decision
	// (2*range+1 examples in total)
	int tot = n_real_examples;
	int range = 5;
	bool* inputs_to_keep = (bool*)xcalloc(n_real_examples,sizeof(bool));

	for(int j=range;j<n_real_examples-range;j++){
		real p_unlikely = unlikely_lp[j];
		real p_likely = likely_lp[j];
		for(int k=1;k<range+1;k++){
			p_unlikely +=  unlikely_lp[j+k] + unlikely_lp[j-k];
			p_likely += likely_lp[j+k] + likely_lp[j-k];
		}
		if(p_likely >= p_unlikely)
			inputs_to_keep[j] = true;
	}
	free(likely_lp);
	free(unlikely_lp);

	if(mask){
		// drop the columns flagged in mask (the original note cites energy as
		// an example), compacting each frame in place; the write index k never
		// overtakes the read index l+n_cols*j, so no unread value is clobbered
		for(int i=0;i<n_files;i++){
			IOHtk* h = htk[i];
			real* vect = h->data;
			int k=0;
			for(int j=0;j<h->n_lines;j++){
				for(int l=0;l<h->n_cols;l++)
					if(!mask[l]){
						vect[k] = vect[l+h->n_cols*j];
						k++;
					}
			}
		}
		// new example width = kept columns (counted on htk[0]) * window_size
		n_inputs = 0;
		for(int l=0;l<htk[0]->n_cols;l++)
			if(!mask[l]){
				n_inputs++;
			}
		n_inputs = n_inputs * window_size;
	}

	n_targets = n_inputs;
	// rebuild all_inputs (and n_examples), keeping only the flagged windows
	prepareData(inputs_to_keep);	
	free(inputs_to_keep);
	return (tot-n_real_examples);
}



// Load input normalization statistics (as written by saveFILE) and
// re-normalize the stored frames with them. Only acts when norm_inputs is
// already set. The frames are first mapped back to raw values with the
// current mean_i/stdv_i. Then one frame's worth (n_inputs/window_size) of
// means, followed by the same number of standard deviations, is read from
// `file`. Finally the frames are normalized again with the loaded values.
// All of this works on a temporary one-pointer-per-frame view, so each
// stored frame is touched exactly once even when windows overlap.
// NOTE(review): StdDataSet::loadFILE is not called here; confirm no base
// state needs loading.
void HtkFileDataSet::loadFILE(FILE *file)
{
  if(norm_inputs) {
    // switch to the per-frame view (restored below)
    int old_n_real_examples = n_real_examples;
    int old_n_inputs = n_inputs;
    n_inputs = n_inputs/window_size;
    n_targets = n_inputs;
    n_real_examples = 0;
    for (int i=0;i<n_files;i++) {
      n_real_examples += htk[i]->n_lines;
    }
    real** old_all_inputs = all_inputs;
    all_inputs = (real**)xalloc(sizeof(real*)*n_real_examples);
    int k=0;
    for (int i=0;i<n_files;i++) {
      for (int j=0;j<htk[i]->n_lines;j++,k++)
      all_inputs[k] = &htk[i]->data[j*n_inputs];
    }
    all_targets = all_inputs;

    // undo the current normalization
    for(int i = 0; i < n_real_examples; i++) {
      for(int d = 0; d < n_inputs; d++)
        all_inputs[i][d] = all_inputs[i][d]*stdv_i[d]+mean_i[d];
    }

    // read the new statistics, in the order saveFILE writes them
    xfread(mean_i, sizeof(real), n_inputs, file);
    xfread(stdv_i, sizeof(real), n_inputs, file);
  
    // apply the loaded statistics
    for(int i = 0; i < n_real_examples; i++) {
      for(int d = 0; d < n_inputs; d++)
        all_inputs[i][d] = (all_inputs[i][d]-mean_i[d])/stdv_i[d];
    }

    // drop the temporary table and restore the windowed view
    free(all_inputs);
    all_inputs = old_all_inputs;
    all_targets = all_inputs;
    n_real_examples = old_n_real_examples;
    n_inputs = old_n_inputs;
    n_targets = n_inputs;
  }
}

// Write the input normalization statistics, if any, to `file`. This is one
// frame's worth (n_inputs/window_size values) of mean_i followed by the same
// number of stdv_i values. Nothing is written when norm_inputs is not set.
void HtkFileDataSet::saveFILE(FILE *file)
{
  if (!norm_inputs)
    return;

  const int frame_size = n_inputs / window_size;
  xfwrite(mean_i, sizeof(real), frame_size, file);
  xfwrite(stdv_i, sizeof(real), frame_size, file);
}

void HtkFileDataSet::createMaskFromParam(bool* mask)
{
  htk[0]->createMaskFromParam(mask);
} 

// Delete the per-file IOHtk readers, then the table holding them, then the
// example pointer table. all_targets is not freed separately because it
// aliases all_inputs.
HtkFileDataSet::~HtkFileDataSet()
{
  IOHtk** const readers_end = htk + n_files;
  for (IOHtk** reader = htk; reader < readers_end; ++reader)
    delete *reader;
  free(htk);
  free(all_inputs);
}

}

