// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


#include "SpeechHMM.h"
#include "log_add.h"
#include "PhonemeSeqDataSet.h"

namespace Torch {

// Builds a speech HMM on top of a set of phoneme-level HMMs.
//
// n_models_              : number of entries in models_
// models_                : the phoneme HMMs; models_[0] must exist, since its
//                          prior_transitions and data are handed to the HMM
//                          base-class constructor
// model_names_           : names of the models, stored as given
// dict_                  : dictionary; dict->words[w][k] is used throughout
//                          this file as an index into models
// grammar_               : grammar used by prepareTestModel()
// word_entrance_penalty_ : value added to the Viterbi score on transitions
//                          flagged in word_transitions
// model_trainer_         : optional trainer used by reset() to train each
//                          model separately (may be NULL)
SpeechHMM::SpeechHMM(int n_models_, HMM **models_, char** model_names_, Dictionary* dict_, Grammar* grammar_, real word_entrance_penalty_, EMTrainer* model_trainer_) : HMM(1,(Distribution**)models_,models_[0]->prior_transitions,models_[0]->data,NULL,NULL,0)
{
  n_models = n_models_;
  models = models_;
  if (models) {
    // Take the dimensions from models[0], the same model the base-class
    // initializer relies on: it is the only entry guaranteed to exist
    // (models[1] is out of bounds when n_models == 1).
    n_observations = models[0]->n_observations;
    n_inputs = models[0]->n_inputs;
  }
  model_names = model_names_;
  model_trainer = model_trainer_;
  dict = dict_;
  grammar = grammar_;
  edit_distance = new EditDistance();

  // the target word buffer is grown on demand by decode()
  target_word_sequence = NULL;
  target_word_sequence_size = 0;
  target_word_sequence_max_size = 0;

  word_entrance_penalty = word_entrance_penalty_;
}

// Loads every sub-model from `file`, one after the other in index order.
void SpeechHMM::loadFILE(FILE *file)
{
  int m = 0;
  while (m < n_models) {
    models[m]->loadFILE(file);
    m++;
  }
}

// Saves every sub-model to `file`, one after the other in index order.
void SpeechHMM::saveFILE(FILE *file)
{
  int m = 0;
  while (m < n_models) {
    models[m]->saveFILE(file);
    m++;
  }
}

// Allocates the buffers owned by this object: the output list, the
// per-frame tables (max_n_frames rows of max_n_states entries) and the
// per-state tables (max_n_states rows of max_n_states entries).
// The parameter lists of the sub-models are appended to params/der_params.
void SpeechHMM::allocateMemory()
{
  n_params = numberOfParams();
  addToList(&outputs,n_outputs,(real*)xalloc(sizeof(real)*n_outputs));
  int m = 0;
  while (m < n_models) {
    addToList(&params,models[m]->params);
    addToList(&der_params,models[m]->der_params);
    m++;
  }

  // scan the dataset (when there is one) for the largest number of frames
  // and the largest number of graph states needed by any example
  max_n_frames = 3;
  max_n_states = 2;
  if (data) {
    for (int e=0;e<data->n_examples;e++) {
      data->setExample(e);
      if (max_n_frames < data->n_frames+2)
        max_n_frames = data->n_frames+2;
      // 2 non-emitting states plus the emitting states of every phoneme
      int states_needed = 2;
      for (int t=0;t<data->n_seqtargets;t++) {
        int w = (int)data->seqtargets[t][0];
        for (int p=0;p<dict->word_length[w];p++)
          states_needed += models[dict->words[w][p]]->n_states-2;
      }
      if (max_n_states < states_needed)
        max_n_states = states_needed;
    }
  }

  // tables indexed by frame
  log_probabilities_s = (real**)xalloc(sizeof(real*)*max_n_frames);
  log_alpha = (real**)xalloc(sizeof(real*)*max_n_frames);
  log_beta = (real**)xalloc(sizeof(real*)*max_n_frames);
  arg_viterbi = (int**)xalloc(sizeof(int*)*max_n_frames);
  viterbi_sequence = (int*)xalloc(sizeof(int)*max_n_frames);
  word_sequence = (int*)xalloc(sizeof(int)*max_n_frames);
  int f = 0;
  while (f < max_n_frames) {
    log_probabilities_s[f] = (real*)xalloc(sizeof(real)*max_n_states);
    log_alpha[f] = (real*)xalloc(sizeof(real)*max_n_states);
    log_beta[f] = (real*)xalloc(sizeof(real)*max_n_states);
    arg_viterbi[f] = (int*)xalloc(sizeof(int)*max_n_states);
    f++;
  }

  // tables indexed by state
  states = (Distribution**)xalloc(sizeof(Distribution*)*max_n_states);
  states_to_model_states = (int*)xalloc(sizeof(int)*max_n_states);
  states_to_model = (int*)xalloc(sizeof(int)*max_n_states);
  states_to_word = (int*)xalloc(sizeof(int)*max_n_states);
  log_transitions = (real**)xalloc(sizeof(real*)*max_n_states);
  word_transitions = (bool**)xalloc(sizeof(bool*)*max_n_states);
  int s = 0;
  while (s < max_n_states) {
    log_transitions[s] = (real*)xalloc(sizeof(real)*max_n_states);
    word_transitions[s] = (bool*)xalloc(sizeof(bool)*max_n_states);
    s++;
  }
}

// Releases everything obtained in allocateMemory()/realloc().
// Guarded by is_free, so a second call does nothing.
void SpeechHMM::freeMemory()
{
  if (!is_free) {
    is_free = true;

    // only the outputs list is freed with its flag set to true
    freeList(&outputs,true);
    freeList(&params,false);
    freeList(&der_params,false);

    // rows of the state-by-state matrices
    int s = 0;
    while (s < max_n_states) {
      free(log_transitions[s]);
      free(word_transitions[s]);
      s++;
    }

    // rows of the frame-by-state matrices
    int f = 0;
    while (f < max_n_frames) {
      free(log_probabilities_s[f]);
      free(log_alpha[f]);
      free(log_beta[f]);
      free(arg_viterbi[f]);
      f++;
    }

    // finally the top-level arrays
    free(states);
    free(states_to_model_states);
    free(states_to_model);
    free(states_to_word);
    free(log_transitions);
    free(word_transitions);
    free(log_alpha);
    free(log_beta);
    free(arg_viterbi);
    free(viterbi_sequence);
    free(word_sequence);
    free(log_probabilities_s);
  }
}

// Returns the sum of numberOfParams() over all sub-models.
int SpeechHMM::numberOfParams()
{
  int total = 0;
  int m = 0;
  while (m < n_models) {
    total += models[m]->numberOfParams();
    m++;
  }
  return total;
}

// Initializes every model before training.
// For each emitting state of each model, the frames that the alignment stored
// in the examples (or, when there is none, a linear alignment) assigns to that
// state are selected in the dataset and the state's own reset() is called on
// that subset. The log_transitions of each model are then recomputed from its
// transitions and, when a model_trainer was supplied, each model is trained
// separately on a PhonemeSeqDataSet built for it.
// Side effect: the member n_states is overwritten (used as a counter here).
void SpeechHMM::reset()
{
  // initialize model
  // if alignment information is given in the dataset, use it.
  // otherwise, do a linear alignment along the states

  // scratch buffers: frame indices selected in the current example, and
  // indices of the examples that contributed at least one frame
  // NOTE(review): selected_frames holds max_n_frames entries and nothing
  // below checks n_selected_frames against that bound.
  int* selected_frames = (int*)xalloc(sizeof(int)*max_n_frames);
  int n_selected_frames = 0;
  int* selected_examples = (int*)xalloc(sizeof(int)*data->n_examples);
  int n_selected_examples = 0;
  for (int m=0;m<n_models;m++) {
    // first the emissions
    // l runs over the emitting states only (first and last state skipped)
    for (int l=1;l<models[m]->n_states-1;l++) {
      n_selected_examples = 0;
      for (int i=0;i<data->n_examples;i++) {
        data->setExample(i);
        SeqExample *ex = (SeqExample*)data->inputs->ptr;
        // compute the number of states of this example
        // and check if the current model m is used as target
        n_states = 0;
        bool model_is_used = false;
        for (int j=0;j<ex->n_seqtargets;j++) {
          int word = (int)ex->seqtargets[j][0];
          for (int k=0;k<dict->word_length[word];k++) {
            n_states += models[dict->words[word][k]]->n_states-2;
            if (dict->words[word][k] == m)
              model_is_used = true;
          }
        }
        // examples that never contain model m contribute nothing
        if (!model_is_used)
          continue;
        n_selected_frames = 0;
        if (ex->n_alignments > 0) {
          // an alignment is available: segment j spans
          // [alignment[j-1], alignment[j]) and is split evenly between the
          // emitting states of the model
          for (int j=0;j<ex->n_alignments;j++) {
            if (ex->alignment_phoneme[j] == m) {
              real start_align = j == 0 ? 0 : (real)ex->alignment[j-1];
              real end_align = (real)ex->alignment[j];
              real n_align = (end_align-start_align)/(models[m]->n_states-2);
              start_align += ((l-1) * n_align);
              // NOTE(review): l < n_states-1 in this loop, so the second
              // test below is always true.
              if (start_align + n_align < end_align && l != models[m]->n_states-1)
                end_align = start_align + n_align;
              for (int n=(int)rint(start_align);n<(int)rint(end_align);n++) {
                selected_frames[n_selected_frames++] = n;
              }
            }
          }
        } else {
          // do a linear alignment for initialization
          // every emitting state of the sentence gets the same number of
          // frames (integer division: leftover frames are not assigned)
          int n_frames_per_state = ex->n_frames / n_states;
          int current_n_frames = 0;
          for (int j=0;j<ex->n_seqtargets;j++) {
            int word = (int)ex->seqtargets[j][0];
            for (int k=0;k<dict->word_length[word];k++) {
              if (dict->words[word][k] == m) {
                for (int n=0;n<n_frames_per_state;n++) {
                  selected_frames[n_selected_frames++] = current_n_frames +
                    (l-1)*n_frames_per_state+n;
                }
              }
              // skip over all the frames given to this phoneme
              current_n_frames += n_frames_per_state * 
                (models[dict->words[word][k]]->n_states-2);
            }
          }
        }
        if (n_selected_frames>0) {
          data->setSelectedFrames(selected_frames,n_selected_frames);
          selected_examples[n_selected_examples++] = i;
        }
      }
      if (n_selected_examples>0) {
        // reset the state on the selected examples/frames only, then restore
        // the dataset
        data->pushSubset(selected_examples,n_selected_examples);
        message("initializing state %d of model %d with proper data",l,m);
        models[m]->states[l]->reset();
        data->unsetAllSelectedFrames();
        data->popSubset();
      } else {
        // nothing was selected for this state: reset it on the dataset as is
        message("initializing state %d of model %d with all data",l,m);
        models[m]->states[l]->reset();
      }
    }
    // then the transitions
    // log_transitions = log(transitions), with LOG_ZERO for entries <= 0
    for (int i=0;i<models[m]->n_states;i++) {
      real *p = models[m]->transitions[i];
      real *lp = models[m]->log_transitions[i];
      for (int j=0;j<models[m]->n_states;j++,lp++,p++) {
        if (*p > 0)
          *lp = log(*p);
        else 
          *lp = LOG_ZERO; 
      } 
    }
    // finally, eventually train separately each model using linear or
    // given segmentation
    if (model_trainer) {
      message("initialize model %d separately with alignment",m);
      PhonemeSeqDataSet psd(data,m,dict,models);
      psd.init();
      // point the shared trainer at this model and at its dataset
      model_trainer->machine = models[m];
      model_trainer->distribution = models[m];
      model_trainer->data = &psd;
      model_trainer->sdata = &psd;
      model_trainer->train(NULL);
    }
  }
  free(selected_frames);
  free(selected_examples);
}

// Connects the last phoneme model of `word` to the first phoneme model of
// `next_word` in the global transition matrix.
//
// word, next_word    : word indices in the dictionary
// current_state      : global index of the first state of `word`
// next_current_state : global index of the first state of `next_word`
// log_n_next         : value subtracted from the log score of each connection
//                      (callers in this file pass LOG_ONE or the log of the
//                      number of successors of `word` in the grammar)
//
// A connection is written from emitting state j of the last model to emitting
// state k of the next model only when both the exit transition of j and the
// entry transition of k exist. The entry is flagged in word_transitions in
// exactly those cases, so that a word boundary is never reported on an entry
// where no inter-word transition was written.
void SpeechHMM::addConnectionsBetweenWordsToModel(int word,int next_word, int current_state,int next_current_state, real log_n_next)
{
  int n_states_word = nStatesInWord(word);
  int current_model = dict->words[word][dict->word_length[word]-1];
  int n_states_model = models[current_model]->n_states;
  int next_model = dict->words[next_word][0];
  int n_states_next_model = models[next_model]->n_states;
  for (int j=1;j<n_states_model;j++) {
    if (models[current_model]->log_transitions[n_states_model-1][j] != LOG_ZERO) {
      // global index of state j of the last model of `word`
      int from_state = current_state+n_states_word-n_states_model+1+j;
      for (int k=1;k<n_states_next_model-1;k++) {
        if (models[next_model]->log_transitions[k][0] != LOG_ZERO) {
          // global index of state k of the first model of `next_word`
          int to_state = next_current_state+k-1;
          log_transitions[to_state][from_state] =
            models[current_model]->log_transitions[n_states_model-1][j] +
            models[next_model]->log_transitions[k][0] - log_n_next;
          word_transitions[to_state][from_state] = true;
        }
      }
    }
  }
}

int SpeechHMM::addWordToModel(int word, int current_state)
{
  for (int l=0;l<dict->word_length[word];l++) {
    int current_model = dict->words[word][l];
    // for each emitting state of the current model
    int n_states_model = models[current_model]->n_states;
    for (int j=1;j<n_states_model-1;j++,current_state++) {
      states[current_state] = models[current_model]->states[j];
      states_to_model_states[current_state] = j;
      states_to_model[current_state] = current_model;
      states_to_word[current_state] = word;
      // for each transition from current_state
      for (int k=1;k<n_states_model-1;k++) {
        log_transitions[current_state+k-j][current_state] = 
          models[current_model]->log_transitions[k][j];
      }
    }
    // add transitions between phonemes
    if (l<dict->word_length[word]-1) {
      int next_model = dict->words[word][l+1];
      int n_states_next_model = models[next_model]->n_states;
      for (int j=1;j<n_states_model;j++) {
         if (models[current_model]->log_transitions[n_states_model-1][j] != LOG_ZERO) {
         for (int k=1;k<n_states_next_model-1;k++) {
            if (models[next_model]->log_transitions[k][0] != LOG_ZERO)
              log_transitions[current_state+k-1][current_state-n_states_model+1+j] = 
                models[current_model]->log_transitions[n_states_model-1][j] +
                models[next_model]->log_transitions[k][0];
          }
        }
      }
    }
  }
  return current_state;
}

// Builds the global state graph of one training example: the phoneme models
// of its target words are chained in order, between one initial and one final
// non-emitting state.
//
// inputs : list whose ptr is the SeqExample to prepare; its seqtargets give
//          the word sequence (seqtargets[0] is read unconditionally)
//
// Sets n_states, grows the work tables through realloc() when needed, and
// fills states, states_to_*, log_transitions and word_transitions.
void SpeechHMM::prepareTrainModel(List* inputs)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  // create the new transition matrix, based on the models and the target sentence
  // 2 non-emitting states plus the emitting states of every phoneme
  n_states = 2;
  for (int j=0;j<ex->n_seqtargets;j++) {
    int word = (int)ex->seqtargets[j][0];
    for (int k=0;k<dict->word_length[word];k++)
      n_states += models[dict->words[word][k]]->n_states-2;
  }
  // first realloc if necessary
  int n_frames = ex->n_frames+2;
  realloc(n_frames,n_states);

  // then put all transitions to 0
  // word_transitions is cleared as well (as prepareTestModel() does):
  // logViterbi() reads it, and without this the flags of a previous example
  // (or never-initialized memory) would be used
  for (int i=0;i<n_states;i++) {
    for (int j=0;j<n_states;j++) {
      log_transitions[i][j] = LOG_ZERO;
      word_transitions[i][j] = false;
    }
  }

  // the transitions from the initial state
  int word = (int)ex->seqtargets[0][0];
  int current_model = dict->words[word][0];
  states_to_model_states[0] = 0;
  states_to_model[0] = current_model;
  states_to_word[0] = word;
  // the initial and final states have no emission distribution
  states[0] = NULL;
  states[n_states-1] = NULL;
  for (int j=1;j<models[current_model]->n_states;j++)
    log_transitions[j][0] = models[current_model]->log_transitions[j][0];
  int current_state = 1;
  for (int i=0;i<ex->n_seqtargets;i++) {
    word = (int)ex->seqtargets[i][0];
    int next_current_state = addWordToModel(word, current_state);
    if (i<ex->n_seqtargets-1) {
      int next_word = (int)ex->seqtargets[i+1][0];
      // add transitions between words
      addConnectionsBetweenWordsToModel(word,next_word,current_state,
        next_current_state,LOG_ONE);
    } else {
      // add last transitions
      // (from the emitting states of the last phoneme to the final state,
      // which is next_current_state at this point)
      current_model = dict->words[word][dict->word_length[word]-1];
      int n_states_in_model = models[current_model]->n_states;
      for (int j=1;j<n_states_in_model-1;j++)
        log_transitions[next_current_state][next_current_state-n_states_in_model+1+j] = models[current_model]->log_transitions[n_states_in_model-1][j];
    }
    current_state = next_current_state;
  }
}

int SpeechHMM::nStatesInWord(int word)
{
  int word_n_states=0;
  for (int j=0;j<dict->word_length[word];j++) {
    word_n_states += models[dict->words[word][j]]->n_states - 2;
  }
  return word_n_states;
}

// Returns the number of states of the graph built from the grammar: two
// non-emitting states plus the emitting states of every grammar word except
// the first and last entries of grammar->words.
int SpeechHMM::nStatesInGrammar()
{
  int total = 2;
  int g = 1;
  while (g < grammar->n_words-1) {
    total += nStatesInWord(grammar->words[g]);
    g++;
  }
  return total;
}

void SpeechHMM::realloc(int n_frames, int n_states_)
{
  n_states = n_states_;
  if (n_frames > max_n_frames) {
    int old_max = max_n_frames;
    max_n_frames = n_frames;
    log_probabilities_s = (real**)xrealloc(log_probabilities_s,sizeof(real*)*max_n_frames);
    log_alpha = (real**)xrealloc(log_alpha,sizeof(real*)*max_n_frames);
    log_beta = (real**)xrealloc(log_beta,sizeof(real*)*max_n_frames);
    arg_viterbi = (int**)xrealloc(arg_viterbi,sizeof(int*)*max_n_frames);
    viterbi_sequence = (int*)xrealloc(viterbi_sequence,sizeof(int)*max_n_frames);
    word_sequence = (int*)xrealloc(word_sequence,sizeof(int)*max_n_frames);
    for (int i=old_max;i<max_n_frames;i++) {
      log_probabilities_s[i] = (real*)xalloc(sizeof(real)*n_states);
      log_alpha[i] = (real*)xalloc(sizeof(real)*n_states);
      log_beta[i] = (real*)xalloc(sizeof(real)*n_states);
      arg_viterbi[i] = (int*)xalloc(sizeof(int)*n_states);
    }
  }
  if (n_states > max_n_states) {
    int old_max = max_n_states;
    max_n_states = n_states;
    states = (Distribution**)xrealloc(states,sizeof(Distribution*)*max_n_states);
    states_to_model_states = (int*)xrealloc(states_to_model_states,sizeof(int)*max_n_states);
    states_to_model = (int*)xrealloc(states_to_model,sizeof(int)*max_n_states);
    states_to_word = (int*)xrealloc(states_to_word,sizeof(int)*max_n_states);
    for (int i=0;i<max_n_frames;i++) {
      log_probabilities_s[i] = (real*)xrealloc(log_probabilities_s[i],sizeof(real)*max_n_states);
      log_alpha[i] = (real*)xrealloc(log_alpha[i],sizeof(real)*max_n_states);
      log_beta[i] = (real*)xrealloc(log_beta[i],sizeof(real)*max_n_states);
      arg_viterbi[i] = (int*)xrealloc(arg_viterbi[i],sizeof(int)*max_n_states);
    }
    log_transitions = (real**)xrealloc(log_transitions,sizeof(real*)*max_n_states);
    word_transitions = (bool**)xrealloc(word_transitions,sizeof(bool*)*max_n_states);
    for (int i=0;i<old_max;i++) {
      log_transitions[i] = (real*)xrealloc(log_transitions[i],sizeof(real)*max_n_states);;
      word_transitions[i] = (bool*)xrealloc(word_transitions[i],sizeof(bool)*max_n_states);;
    }
    for (int i=old_max;i<max_n_states;i++) {
      log_transitions[i] = (real*)xalloc(sizeof(real)*max_n_states);;
      word_transitions[i] = (bool*)xalloc(sizeof(bool)*max_n_states);;
    }
  }
}

// Builds the global state graph used for decoding: every word of the grammar
// (except the first and last entries of grammar->words) is added once, and
// words are connected according to grammar->transitions.
//
// inputs : list whose ptr is the SeqExample to decode; only its number of
//          frames is used here, to size the work tables
//
// Sets n_states and grammar->start[], and fills states, states_to_*,
// log_transitions and word_transitions.
void SpeechHMM::prepareTestModel(List* inputs)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  // create the new transition matrix, based on the models and the target sentence
  // first realloc if necessary
  int n_frames = ex->n_frames+2;
  n_states = nStatesInGrammar();
  realloc(n_frames,n_states);

  // then put all transitions to 0
  for (int i=0;i<n_states;i++) {
    for (int j=0;j<n_states;j++) {
      log_transitions[i][j] = LOG_ZERO;
      word_transitions[i][j] = false;
    }
  }

  // then create the new transition matrix
  // the initial state belongs to no word and no model; the initial and final
  // states have no emission distribution
  states_to_word[0] = -1;
  states_to_model_states[0] = -1;
  states_to_model[0] = -1;
  states[0] = NULL;
  states[n_states-1] = NULL;
  // the transitions from the initial state
  // first count the total probability, and normalize it
  real total_prob = 0;
  for (int i=1;i<grammar->n_words-1;i++) {
    if (grammar->transitions[i][0]) {
      int m = grammar->words[i];
      int p = dict->words[m][0];
      for (int k=1;k<models[p]->n_states;k++)
        total_prob += exp(models[p]->log_transitions[k][0]);
    }
  }
  // then update the probabilities
  // entry_state is the global index of the first state of grammar word i.
  // It has to advance for EVERY word, reachable from the start or not,
  // so that it matches grammar->start[i] as assigned below.
  int entry_state=1;
  real log_total_prob = log(total_prob);
  for (int i=1;i<grammar->n_words-1;i++) {
    int m = grammar->words[i];
    if (grammar->transitions[i][0]) {
      int p = dict->words[m][0];
      for (int k=1;k<models[p]->n_states;k++)
        log_transitions[entry_state+k-1][0] = models[p]->log_transitions[k][0] - 
          log_total_prob;
    }
    entry_state += nStatesInWord(m);
  }

  //then, for each word in the grammar, add it
  int current_state = 1;
  for (int i=1;i<grammar->n_words-1;i++) {
    grammar->start[i] = current_state;
    current_state = addWordToModel(grammar->words[i],current_state);
  }
  // then add the transitions between words
  for (int i=1;i<grammar->n_words-1;i++) {
    int word = grammar->words[i];
    // count the transitions starting from word
    real log_n_transitions = 0;
    for (int j=1;j<grammar->n_words;j++) {
      log_n_transitions += (grammar->transitions[j][i]);
    }
    // turn the count into its log (LOG_ONE when the word has no successor)
    log_n_transitions = log_n_transitions>0 ? log(log_n_transitions) : LOG_ONE;
    for (int j=1;j<grammar->n_words;j++) {
      if (grammar->transitions[j][i]) {
        int next_word = grammar->words[j];
        // a grammar entry whose word is -1 is handled as the end of sentence
        if (next_word != -1) {
          // add transitions between words
          addConnectionsBetweenWordsToModel(word,next_word,grammar->start[i],
            grammar->start[j],log_n_transitions);
        } else {
          // add last transitions
          int current_model = dict->words[word][dict->word_length[word]-1];
          int n_states_in_model = models[current_model]->n_states;
          int n_states_in_word = nStatesInWord(word);
          int last_state = n_states-1;
          for (int k=1;k<n_states_in_model-1;k++)
            log_transitions[last_state][grammar->start[i]+n_states_in_word-n_states_in_model+2+k-1] = models[current_model]->log_transitions[n_states_in_model-1][k];
        }
      }
    }
  }
/*
  printTransitions(false,true);
  for (int i=0;i<n_states;i++) {
    printf("state %d corresponds to state %d in model %d in word %d\n",i,states_to_model_states[i],states_to_model[i],states_to_word[i]);
  }
*/
}

// Forwards eMSequenceInitialize() to every sub-model, then builds the
// training graph of the example in `inputs`.
void SpeechHMM::eMSequenceInitialize(List* inputs)
{
  int m = 0;
  while (m < n_models) {
    models[m]->eMSequenceInitialize(inputs);
    m++;
  }

  prepareTrainModel(inputs);
}

// Forwards sequenceInitialize() to every sub-model, then builds the
// training graph of the example in `inputs`.
void SpeechHMM::sequenceInitialize(List* inputs)
{
  int m = 0;
  while (m < n_models) {
    models[m]->sequenceInitialize(inputs);
    m++;
  }

  prepareTrainModel(inputs);
}

// Forwards eMIterInitialize() to every sub-model, in index order.
void SpeechHMM::eMIterInitialize()
{
  int m = 0;
  while (m < n_models) {
    models[m]->eMIterInitialize();
    m++;
  }
}

// Forwards iterInitialize() to every sub-model, in index order.
void SpeechHMM::iterInitialize()
{
  int m = 0;
  while (m < n_models) {
    models[m]->iterInitialize();
    m++;
  }
}

// Accumulates the EM statistics of one sequence into the sub-models.
//
// inputs        : list whose ptr is the SeqExample being processed
// log_posterior : log weight added to every accumulated posterior
//
// logBeta() is run first; log_alpha and log_probability are read here but not
// computed here. Emission posteriors are passed to the state distributions
// through frameEMAccPosteriors(); transition posteriors are added to the
// transitions_acc tables of the phoneme models.
void SpeechHMM::eMAccPosteriors(List *inputs, real log_posterior)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  real* in = NULL;
  real* obs = NULL;

  // compute the beta by backward recursion
  logBeta(ex);

  // accumulate the emission and transition posteriors
  for (int f=0;f<ex->n_frames;f++) {
    // in/obs stay NULL when the example has no inputs/observations
    if (ex->inputs)
      in = ex->inputs[f];
    if (ex->observations)
      obs = ex->observations[f];
    // emitting states only; data frame f uses row f+1 of log_alpha/log_beta
    for (int i=1;i<n_states-1;i++) {
      real log_posterior_i_f = log_posterior + log_alpha[f+1][i] +
        log_beta[f+1][i] - log_probability;
      real log_emit_i = states[i]->log_probabilities[f];
      states[i]->frameEMAccPosteriors(obs,log_posterior_i_f,in,f);
      // map the global state i back to (model, state inside the model)
      int model_to = states_to_model[i];
      int state_to = states_to_model_states[i];
      for (int j=0;j<n_states;j++) {
        if (log_transitions[i][j] == LOG_ZERO)
          continue;
        // find the real transition
        int model_from = states_to_model[j];
        int state_from = states_to_model_states[j];
        if (model_from == model_to) {
          // transition inside one model
          models[model_from]->transitions_acc[state_to][state_from] +=
            exp(log_posterior + log_alpha[f][j] + log_transitions[i][j] + 
              log_emit_i + log_beta[f+1][i] - log_probability);
        } else {
          // transition between two models: it is accounted for twice, as an
          // exit (to the last state) of the source model and as an entry
          // (from state 0) of the destination model, each weighted with the
          // corresponding model transition
          int last_state_from = models[model_from]->n_states-1;
          models[model_from]->transitions_acc[last_state_from][state_from] +=
            exp(log_posterior + log_alpha[f][j] + 
            models[model_from]->log_transitions[last_state_from][state_from] +
            log_emit_i + log_beta[f+1][i] - log_probability);
          models[model_to]->transitions_acc[state_to][0] +=
            exp(log_posterior + log_alpha[f][j] + 
            models[model_to]->log_transitions[state_to][0] +
            log_emit_i + log_beta[f+1][i] - log_probability);
        }
      }
    }
  }
  // particular case of transitions to last state
  // (the final state emits nothing, hence no emission term below)
  int f = ex->n_frames;
  int i = n_states-1;
  for (int j=0;j<n_states;j++) {
    int model_from = states_to_model[j];
    int state_from = states_to_model_states[j];
    if (log_transitions[i][j] == LOG_ZERO)
      continue;
    int last_state_from = models[model_from]->n_states-1;
    models[model_from]->transitions_acc[last_state_from][state_from] +=
      exp(log_posterior + log_alpha[f][j] + 
      models[model_from]->log_transitions[last_state_from][state_from] +
      log_beta[f+1][i] - log_probability);
  }
}

// Accumulates statistics along the Viterbi path stored in viterbi_sequence
// and arg_viterbi.
//
// inputs        : list whose ptr is the SeqExample being processed
// log_posterior : value passed to the emission accumulators and added to the
//                 transition accumulators
void SpeechHMM::viterbiAccPosteriors(List *inputs, real log_posterior)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  real* frame_input = NULL;
  real* frame_obs = NULL;

  for (int t=0;t<ex->n_frames;t++) {
    // state occupied at this frame (row t+1 of the Viterbi tables)
    int cur = viterbi_sequence[t+1];
    int cur_model = states_to_model[cur];
    int cur_model_state = states_to_model_states[cur];

    // the pointers stay NULL when the example has no inputs/observations
    if (ex->inputs)
      frame_input = ex->inputs[t];
    if (ex->observations)
      frame_obs = ex->observations[t];
    states[cur]->frameEMAccPosteriors(frame_obs,log_posterior,frame_input,t);

    // predecessor on the path, mapped back to (model, state in the model)
    int prev = arg_viterbi[t+1][cur];
    int prev_model = states_to_model[prev];
    int prev_model_state = states_to_model_states[prev];

    if (prev_model == cur_model) {
      // transition inside one model
      models[prev_model]->transitions_acc[cur_model_state][prev_model_state] +=log_posterior;
      continue;
    }

    // transition between two models: count an exit from the source model
    // and an entry into the destination model
    int exit_state = models[prev_model]->n_states-1;
    models[prev_model]->transitions_acc[exit_state][prev_model_state] +=
      log_posterior;
    models[cur_model]->transitions_acc[cur_model_state][0] += log_posterior;
  }
}

// Forwards eMUpdate() to every sub-model, in index order.
void SpeechHMM::eMUpdate()
{
  int m = 0;
  while (m < n_models) {
    models[m]->eMUpdate();
    m++;
  }
}

// Gradient pass for one sequence.
//
// inputs : list whose ptr is the SeqExample being processed
// alpha  : pointer to the incoming derivative; every posterior is multiplied
//          by -(*alpha)
//
// logBeta() is run first; log_alpha and log_probability are read here but not
// computed here. Emission terms are passed to the state distributions through
// frameBackward(); transition terms are accumulated into the
// dlog_transitions tables of the phoneme models. This function only
// accumulates derivatives: it writes neither to the parameters
// (log_transitions) nor to the EM accumulators (transitions_acc).
void SpeechHMM::backward(List *inputs, real *alpha)
{
  SeqExample* ex = (SeqExample*)inputs->ptr;
  real* in = NULL;
  real* obs = NULL;

  // compute the beta by backward recursion
  logBeta(ex);

  // accumulate the emission and transition posteriors
  for (int f=0;f<ex->n_frames;f++) {
    // in/obs stay NULL when the example has no inputs/observations
    if (ex->inputs)
      in = ex->inputs[f];
    if (ex->observations)
      obs = ex->observations[f];
    // emitting states only; data frame f uses row f+1 of log_alpha/log_beta
    for (int i=1;i<n_states-1;i++) {
      real posterior_i_f[1];
      posterior_i_f[0] = - *alpha * exp(log_alpha[f+1][i] + 
        log_beta[f+1][i] - log_probability);
      real log_emit_i = states[i]->log_probabilities[f];
      states[i]->frameBackward(obs,posterior_i_f,in,f);
      // map the global state i back to (model, state inside the model)
      int model_to = states_to_model[i];
      int state_to = states_to_model_states[i];
      for (int j=0;j<n_states;j++) {
        if (log_transitions[i][j] == LOG_ZERO)
          continue;
        int model_from = states_to_model[j];
        int state_from = states_to_model_states[j];
        if (model_from == model_to) {
          // transition inside one model
          real posterior_i_j_f = - *alpha * exp(log_alpha[f][j] +
            log_transitions[i][j] + log_emit_i + log_beta[f+1][i] - 
            log_probability);
          models[model_from]->dlog_transitions[state_to][state_from] +=
            posterior_i_j_f;
          for (int k=0;k<n_states;k++) {
            if (log_transitions[k][j] == LOG_ZERO)
              continue;
            models[model_from]->dlog_transitions[state_to][state_from] -=
              posterior_i_j_f * exp(log_transitions[k][j]);
          }
        } else {
          // transition between two models: one term for the exit of the
          // source model, one for the entry into the destination model
          int last_state_from = models[model_from]->n_states-1;
          real posterior_i_j_f_from = - *alpha * exp(log_alpha[f][j] +
            models[model_from]->log_transitions[last_state_from][state_from] + 
            log_emit_i + log_beta[f+1][i] - log_probability);
          real posterior_i_j_f_to = - *alpha * exp(log_alpha[f][j] +
            models[model_to]->log_transitions[state_to][0] + 
            log_emit_i + log_beta[f+1][i] - log_probability);
          models[model_from]->dlog_transitions[last_state_from][state_from] +=
            posterior_i_j_f_from;
          // the entry term is a derivative too: it goes to dlog_transitions
          // (matching the subtraction below), not to transitions_acc
          models[model_to]->dlog_transitions[state_to][0] +=
            posterior_i_j_f_to;
          for (int k=0;k<n_states;k++) {
            if (log_transitions[k][j] == LOG_ZERO)
              continue;
            models[model_from]->dlog_transitions[last_state_from][state_from] -=
              posterior_i_j_f_from * exp(models[model_from]->log_transitions[last_state_from][state_from]);
            models[model_to]->dlog_transitions[state_to][0] -=
              posterior_i_j_f_to * exp(models[model_to]->log_transitions[state_to][0]);
          }
        }
      }
    }
  }
  // particular case of transitions to last state
  // (the final state emits nothing, hence no emission term below)
  int f = ex->n_frames;
  int i = n_states-1;
  for (int j=0;j<n_states;j++) {
    if (log_transitions[i][j] == LOG_ZERO)
      continue;
    int model_from = states_to_model[j];
    int state_from = states_to_model_states[j];
    int last_state_from = models[model_from]->n_states-1;
    real posterior_i_j_f = - *alpha * exp(log_alpha[f][j] +
      models[model_from]->log_transitions[last_state_from][state_from] + 
      log_beta[f+1][i] - log_probability);
    // accumulate into the derivative, never into the parameter itself
    models[model_from]->dlog_transitions[last_state_from][state_from] +=
      posterior_i_j_f;
    for (int k=0;k<n_states;k++) {
      if (log_transitions[k][j] == LOG_ZERO)
        continue;
      models[model_from]->dlog_transitions[last_state_from][state_from] -=
        posterior_i_j_f * exp(models[model_from]->log_transitions[last_state_from][state_from]);
    }
  }
}

// Viterbi recursion over the current graph.
// Fills log_alpha with the best path scores and arg_viterbi with the best
// predecessors, then backtracks from the final state into viterbi_sequence.
// Row 0 is the initial state, rows 1..n_frames are the data frames and row
// n_frames+1 is the final state. word_entrance_penalty is added to the score
// of every transition flagged in word_transitions.
void SpeechHMM::logViterbi(SeqExample* ex)
{
  const int final_state = n_states-1;

  // row 0: only the initial state is reachable
  real *start_row = log_alpha[0];
  start_row[0] = LOG_ONE;
  for (int s=1;s<n_states;s++)
    start_row[s] = LOG_ZERO;

  // rows 1..n_frames: emitting states only
  for (int t=1;t<=ex->n_frames;t++) {
    real *cur = log_alpha[t];
    real *prev = log_alpha[t-1];
    real *emit = log_probabilities_s[t];
    int *back = arg_viterbi[t];
    cur[0] = LOG_ZERO;
    cur[final_state] = LOG_ZERO;
    for (int to=1;to<final_state;to++) {
      real *incoming = log_transitions[to];
      bool *enters_word = word_transitions[to];
      cur[to] = LOG_ZERO;
      for (int from=0;from<final_state;from++) {
        if (incoming[from] == LOG_ZERO)
          continue;
        real score = incoming[from] + emit[to] + prev[from];
        if (enters_word[from])
          score += word_entrance_penalty;
        // ties are resolved in favor of the later predecessor
        if (cur[to] <= score) {
          cur[to] = score;
          back[to] = from;
        }
      }
    }
  }

  // row n_frames+1: non-emitting transition into the final state
  int last_row = ex->n_frames+1;
  for (int s=0;s<n_states;s++)
    log_alpha[last_row][s] = LOG_ZERO;
  for (int from=1;from<final_state;from++) {
    if (log_transitions[final_state][from] == LOG_ZERO)
      continue;
    real score = log_transitions[final_state][from] + log_alpha[last_row-1][from];
    if (log_alpha[last_row][final_state] < score) {
      log_alpha[last_row][final_state] = score;
      arg_viterbi[last_row][final_state] = from;
    }
  }

  // backtrack from the final state
  viterbi_sequence[ex->n_frames+1] = final_state;
  int t = ex->n_frames;
  while (t >= 0) {
    viterbi_sequence[t] = arg_viterbi[t+1][viterbi_sequence[t+1]];
    t--;
  }
}


// Decodes one sequence with the grammar graph and scores the result.
//
// input : list whose ptr is the SeqExample to decode
//
// The best state path is converted into word_sequence (silence excluded),
// the target words of the example are copied into target_word_sequence
// (silence excluded), and the accuracy computed by edit_distance between the
// two is written to the first output.
void SpeechHMM::decode(List* input)
{
  SeqExample* ex = (SeqExample*)input->ptr;
  for (int i=0;i<n_models;i++)
    models[i]->eMSequenceInitialize(input);
  prepareTestModel(input);
  logProbabilities(input);
  logViterbi(ex);

  // convert the state sequence to a word sequence
  // The emitting states of the path are viterbi_sequence[1..n_frames]
  // (entry 0 is the initial state, entry n_frames+1 the final one), so the
  // last frame has to be included in the scan.
  word_sequence_size = 0;
  int previous_state = -1;
  for (int i=1;i<=ex->n_frames;i++) {
    int state = viterbi_sequence[i];
    int word = states_to_word[state];
    // do not keep silences and register each time we exit a model
    if (word != dict->silence_word) {
      // a word is emitted on the first frame, and then each time the path
      // follows a transition flagged as a word boundary
      if ((previous_state == -1) ||
          (previous_state>=0 && word_transitions[state][previous_state])) 
        word_sequence[word_sequence_size++] = word;
    }
    previous_state = state;
  }

  // keep in memory the target word sequence
  // (the buffer only grows; its capacity is target_word_sequence_max_size)
  if (ex->n_seqtargets > target_word_sequence_max_size) {
    target_word_sequence = 
      (int*)xrealloc(target_word_sequence,ex->n_seqtargets*sizeof(int));
    target_word_sequence_max_size =target_word_sequence_size = ex->n_seqtargets;
  } else {
    target_word_sequence_size = ex->n_seqtargets;
  }
  int j=0;
  for (int i=0;i<ex->n_seqtargets;i++) {
    int word = (int)ex->seqtargets[i][0];
    if (word != dict->silence_word) {
      target_word_sequence[j++] = word;
    }
  }
  target_word_sequence_size = j;
  
  // then compute the edit distance between the decoded and target word
  // sequences
  // NOTE(review): when the target is empty, edit_distance is left untouched
  // and the accuracy written below is whatever it held before this call.
  if (target_word_sequence_size > 0) {
    edit_distance->reset();
    edit_distance->distance(word_sequence,word_sequence_size,
      target_word_sequence,target_word_sequence_size);
  }

  *(real*)outputs->ptr = edit_distance->accuracy;
}

// Releases the edit-distance helper, the target word buffer (only when it
// was ever allocated) and the work tables.
SpeechHMM::~SpeechHMM()
{
  const bool has_target_buffer = target_word_sequence_max_size > 0;
  delete edit_distance;
  if (has_target_buffer) {
    free(target_word_sequence);
  }
  freeMemory();
}

}

