// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


#include "DistrMachine.h"
#include "log_add.h"

namespace Torch {

DistrMachine::DistrMachine(Distribution* distribution_,GradientMachine *machine_)
{
  // Wrap a distribution whose parameters are produced, frame by frame,
  // by a gradient machine (see frameLogProbability / frameExpectation).
  distribution = distribution_;
  machine = machine_;

  // List heads are populated later by allocateMemory().
  params = NULL;
  der_params = NULL;
  outputs = NULL;

  // Only allocated in allocateMemory(). The destructor always calls
  // freeMemory(), which free()s this buffer, so it must start out NULL:
  // free(NULL) is a no-op, whereas freeing an uninitialized pointer is
  // undefined behaviour if the object dies before allocateMemory() runs.
  der_params_distribution = NULL;

  // Output size comes from the distribution, input size from the machine.
  n_outputs = distribution->n_outputs;
  n_inputs = machine->n_inputs;

  // input_machine wraps the raw per-frame input pointer so it can be passed
  // to machine->forward()/backward(); its ptr is set on every call.
  // (presumably a single-node List — confirm against the List declaration)
  input_machine.n = n_inputs;
  input_machine.next = NULL;
}

void DistrMachine::loadFILE(FILE *file)
{
  // the only parameters are stored in the machine
  machine->loadFILE(file);
}

void DistrMachine::saveFILE(FILE *file)
{
  // the only parameters are stored in the machine
  machine->saveFILE(file);
}

void DistrMachine::allocateMemory()
{
  max_n_frames = distribution->max_n_frames;
  der_params_distribution = (real*)xalloc(sizeof(real)*distribution->n_params);
  addToList(&params,machine->params);
  addToList(&der_params,machine->der_params);
  addToList(&outputs,distribution->outputs);
}

void DistrMachine::freeMemory()
{
  // beta is only ever assigned machine->beta (frameBackward); it is never
  // allocated by this class, so drop the alias instead of freeing it.
  // NOTE(review): presumably this also stops a base-class cleanup from
  // freeing the machine's buffer — confirm against the base destructor.
  beta = NULL;

  // Release the scratch buffer from allocateMemory() and reset the pointer
  // so a second call (e.g. an explicit freeMemory() followed by the
  // destructor's call) cannot free it twice.
  free(der_params_distribution);
  der_params_distribution = NULL;

  // The list entries point at the machine's / distribution's own arrays
  // (added in allocateMemory); the `false` flag presumably frees only the
  // list nodes, not those arrays — confirm against freeList.
  freeList(&outputs,false);
  freeList(&params,false);
  freeList(&der_params,false);

  // Leave the heads empty so repeated calls are harmless.
  outputs = NULL;
  params = NULL;
  der_params = NULL;
}

int DistrMachine::numberOfParams()
{
  return machine->numberOfParams();
}

void DistrMachine::reset()
{
  distribution->reset();
  machine->reset();
}

real DistrMachine::frameLogProbability(real *observations, real *inputs, int t)
{
  // Step 1: run the machine on this frame's inputs.
  this->input_machine.ptr = inputs;
  this->machine->forward(&this->input_machine);

  // Step 2: the machine outputs become the distribution's parameters.
  copyList(this->distribution->params, this->machine->outputs);

  // Step 3 (hack kept on purpose): re-run the distribution's sequence
  // initialization so parameter-derived caches are refreshed after the
  // copy; needed by some distributions, e.g. DiagonalGMM.
  this->distribution->sequenceInitialize(NULL);

  // Step 4: score the observation; the inputs were already consumed above.
  const real log_prob = this->distribution->frameLogProbability(observations, NULL, t);
  return log_prob;
}

void DistrMachine::sequenceInitialize(List* inputs)
{
  // Per-sequence setup is delegated unchanged to the distribution.
  this->distribution->sequenceInitialize(inputs);
}

void DistrMachine::iterInitialize()
{
  distribution->iterInitialize();
  machine->iterInitialize();
}

void DistrMachine::frameBackward(real *observations, real *alpha, real *inputs, int t)
{
  // Back-propagate alpha through the distribution for frame t; this fills
  // distribution->der_params.
  // NOTE(review): presumably relies on the forward state left by
  // frameLogProbability() for the same frame — confirm with callers.
  this->distribution->frameBackward(observations, alpha, NULL, t);

  // The distribution's params are the machine's outputs, so the gradient
  // w.r.t. those params is flattened into the scratch buffer and used as
  // the output gradient fed to machine->backward().
  copyList(this->der_params_distribution, this->distribution->der_params);

  this->input_machine.ptr = inputs;
  this->machine->backward(&this->input_machine, this->der_params_distribution);

  // Expose the machine's beta (not copied, just aliased).
  this->beta = this->machine->beta;
}

void DistrMachine::frameExpectation(real *observations, real *inputs, int t)
{
  // Compute the distribution parameters for this frame by running the
  // machine on its inputs ...
  this->input_machine.ptr = inputs;
  this->machine->forward(&this->input_machine);

  // ... and loading its outputs into the distribution.
  copyList(this->distribution->params, this->machine->outputs);

  // Hack kept on purpose: refresh the distribution's sequence-level state
  // after the parameter copy; some distributions (e.g. DiagonalGMM) need it.
  this->distribution->sequenceInitialize(NULL);

  // Delegate the expectation step; inputs were already consumed above.
  this->distribution->frameExpectation(observations, NULL, t);
}

DistrMachine::~DistrMachine()
{
  // Inside a destructor a virtual call already resolves to this class's
  // override; qualifying it just makes that explicit.
  DistrMachine::freeMemory();
}

}

