classdef Inflation < OptimProblem
  % INFLATION  Regime-dependent inflation model estimated through OptimProblem.
  %
  % Residual equation (see getMLEfunction), with m_t the Monetary regime:
  %   z_t        = y_t - mu_q(m_t) - a_q*(y_{t-1} - mu_q(m_t))
  % Variance recursion:
  %   sig2_{t+1} = sig2_q + beta_q*(sig2_t - sig2_q)
  %                + alpha_q*((z_t - gamma_q*sqrt(sig2_t))^2 - (1+gamma_q^2)*sig2_q)
  % where beta_q = alphabetagamma_q - alpha_q*(1+gamma_q^2) (see getPV).
  %
  % NOTE(review): several methods update self (self.params, self.mle) without
  % returning it, which presumably relies on OptimProblem being a handle
  % class -- confirm.
  properties 
    name          = 'Inflation';
    
    calendar      = [];         % (Optional) calendar associated with the series below
    series        = [];         % Observed inflation series (read as a column vector)

    Monetary      = [];         % A placeholder for the (filtered) Monetary object;
                                % getMLEfunction reads its 'regimes' and 'params.m0'
    
    cond_dist     = 'Normal';   % The (conditional) distribution of the innovations
                                % NOTE: getMLEfunction currently always uses a Normal density
    method        = 'MLE';      % The estimation method (only 'MLE' is implemented)

    model_restr   =  1;         % 1 for Wilkie-like, 
                                % 2 for Extended Wilkie-like, 
                                % 3 for ADG-like, 
                                % 4 for full new ESG  
                                % NOTE: 1, 2 and 3 currently impose identical
                                % restrictions; any other value leaves every
                                % parameter free.
                                
    %%#####  INTERNAL  ############################################################
    mle  = struct();            % A placeholder for MLE results (filled by fminsearch)
  end % end properties
    
  methods 
    %% Construct an Inflation model
    % Usage: obj = Inflation('prop1', value1, 'prop2', value2, ...)
    % Name/value pairs are applied to properties BEFORE the parameters are
    % registered, so a 'model_restr' passed here decides which parameters
    % are fixed.
    function self = Inflation(varargin)
      self = self@OptimProblem();
      if mod(numel(varargin), 2) ~= 0
        error('Inflation:nameValue', 'Options must be given as name/value pairs.');
      end
      for k = 1:2:numel(varargin)
        self.(varargin{k}) = varargin{k+1};
      end
      
      % Inflation parameters: name, initial value and (apparently) the
      % [lower, upper] search bounds expected by OptimProblem.addParameter
      self.addParameter('mu_q1',            0.0028,       [-1.000, 1.00000]);
      self.addParameter('mu_q2',            0.0022,       [-1.000, 1.00000]);
      self.addParameter('mu_q3',            0.0021,       [-1.000, 1.00000]);
      self.addParameter('a_q',              0.3477,       [-1.000, 1.00000]);
      self.addParameter('sig2_q',           5.8426e-06,   [ 0.00,  0.00010]);
      self.addParameter('alpha_q',          0.9206,       [ 0.00,  1.00000]);
      self.addParameter('alphabetagamma_q', 0.9492,       [ 0.00,  0.99000]);
      self.addParameter('gamma_q',         -0.2506,       [  -10,       10]);
      self.addParameter('q0',               0.0063,       [-1.00,  1.00000]);
      self.addParameter('sig2_q_init',      1.3981e-05,   [ 0.00,  0.00010]);
      
      % Restricted models: constant variance (alpha = abg = gamma = 0, hence
      % beta = 0 and sig2_init = sig2_q) and a single, regime-independent mean.
      if self.isRestrictedModel()
        self.params.mu_q2.fixed             = true;
        self.params.mu_q3.fixed             = true;
        self.params.alpha_q.fixed           = true;
        self.params.alphabetagamma_q.fixed  = true;
        self.params.gamma_q.fixed           = true;
        self.params.sig2_q_init.fixed       = true;

        self.params.alpha_q.value           = 0;
        self.params.alphabetagamma_q.value  = 0;
        self.params.gamma_q.value           = 0;
        self = self.tieRestrictedParams();
      end
    end % end Inflation
    
    %% Return the parameter values as a struct
    % Starts from OptimProblem.getPV, re-imposes the ties of the restricted
    % models (on both self.params and the returned struct) and adds the
    % derived field beta_q = alphabetagamma_q - alpha_q*(1 + gamma_q^2).
    function pv = getPV(self)
      pv = getPV@OptimProblem(self);
      
      if self.isRestrictedModel()
        % The free parameters may have moved since construction, so the tied
        % ones are re-synchronised each time the values are read.
        self = self.tieRestrictedParams();
        pv.sig2_q_init = self.params.sig2_q.value;
        pv.mu_q2       = self.params.mu_q1.value;
        pv.mu_q3       = self.params.mu_q1.value;
      end

      pv.beta_q = pv.alphabetagamma_q - pv.alpha_q*(1+pv.gamma_q^2);
    end % end getPV

    %% Run OptimProblem.fminsearch and store the MLE results in self.mle
    % lambda and any extra arguments are forwarded unchanged to the parent.
    % On return self.mle holds the estimated parameters (params), the
    % residuals (z), the conditional variances (sig2) and the raw optimiser
    % output (out).
    function [results] = fminsearch(self, lambda, varargin)
      % Fail fast: only MLE results can be stored, so do not spend an entire
      % optimisation run before rejecting an unsupported method.
      if ~strcmp(self.method, 'MLE')
        error('Method not implemented yet.');
      end

      % Expand varargin so the parent receives the caller's arguments, not a
      % single nested cell.
      results = fminsearch@OptimProblem(self, lambda, varargin{:});
      [~, z, sig2] = self.objective(self.getPValues);

      self.mle.params = self.getPV;
      self.mle.z      = z;
      self.mle.sig2   = sig2;
      self.mle.out    = results;
    end % end fminsearch

    %% Objective function minimised by the optimiser
    % For 'MLE' this is the negative log-likelihood at parameter vector x;
    % z and sig2 are the residuals and conditional variances at x.
    function [S,z,sig2] = objective(self, x, varargin)
      switch self.method
        case 'MLE'
          [logl,z,sig2] = self.getMLEfunction(x);
          S = -sum(logl);
        otherwise
          % Without this branch the outputs would be unassigned and MATLAB
          % would raise an unrelated "output not assigned" error.
          error('Method not implemented yet.');
      end
    end % end objective

    %% Per-observation log-likelihood of the inflation model
    % Inputs : p    - parameter vector, loaded through setPValues
    % Outputs: logl - log-likelihood contribution of each observation
    %                 (all NaN when the variance path is invalid)
    %          z    - residuals, one per observation
    %          sig2 - conditional variance used for each observation
    function [logl,z,sig2] = getMLEfunction(self,p)
      % Load the candidate parameters and read them back (with ties applied)
      self.setPValues(p);
      pv = self.getPV;
      
      % y(1) is the initial level q0 and y(2:T) the observations. The series
      % is forced to a column so that row-vector inputs also work.
      y = [pv.q0; self.series(:)];
      T = numel(y);

      % m(1) is the initial regime (from m0), m(2:T) the filtered regime path
      m = [floor(self.Monetary.params.m0.value); self.Monetary.regimes(:)];
      if numel(m) < T
        error('Inflation:regimeLength', ...
              'The Monetary regime path has %d entries but %d are required.', ...
              numel(m) - 1, T - 1);
      end

      mus  = [pv.mu_q1, pv.mu_q2, pv.mu_q3];
      z    = zeros(T,1);
      sig2 = zeros(T+1,1);
      
      % sig2(t) is the variance used for observation t; the recursion
      % produces the next one, sig2(t+1)
      sig2(2) = pv.sig2_q_init;
      for t = 2:T
        mu_t = mus(m(t));
        z(t) = y(t) - mu_t - pv.a_q*(y(t-1) - mu_t);
        sig2(t+1) = pv.sig2_q + pv.beta_q*(sig2(t)-pv.sig2_q) + pv.alpha_q*( (z(t) - pv.gamma_q.*sqrt(sig2(t))).^2 - (1+pv.gamma_q.^2)*pv.sig2_q);
      end
      % Drop the q0 slot and the variance forecast beyond the last observation
      sig2 = sig2(2:end-1);
      z    = z(2:end);

      % Closed-form Normal log-density: equal to log(normpdf(z,0,sqrt(sig2)))
      % but it does not underflow to -Inf for large standardized residuals.
      logl = -0.5*(log(2*pi*sig2) + z.^2./sig2);
     
      % A non-positive or complex variance path is not a valid model. In that
      % case an array of NaN is returned. The imaginary parts are checked
      % element-wise because they could cancel out in a sum.
      if any(imag(logl) ~= 0) || any(real(sig2) <= 0)
        logl = NaN(size(logl));
      end
    end % end getMLEfunction
    
    %% Log prior of the inflation model
    % pv is a parameter struct as returned by getPV (it must contain beta_q).
    % Bounded parameters use a log-indicator term, so the result is -Inf
    % outside the admissible region.
    function logprior = getlogPrior(self,pv)
      logprior =  log(normpdf(pv.mu_q1,0,1)) + ...
                  log(normpdf(pv.mu_q2,0,1)) + ...
                  log(normpdf(pv.mu_q3,0,1)) + ...
                  log((pv.a_q <= 0.995).*(pv.a_q >= -0.995)) + ...                  % Hack: Otherwise, we might have some stationarity issues...
                  log(exppdf(pv.sig2_q,1)) + ...
                  log((pv.alpha_q <= 0.995).*(pv.alpha_q >= 0)) + ...                         % Hack: Otherwise, we might have some stationarity issues...
                  log((pv.alphabetagamma_q <= 0.995).*(pv.alphabetagamma_q >= 0)) + ...       % Hack: Otherwise, we might have some stationarity issues...
                  log((pv.beta_q <= 0.995).*(pv.beta_q >= 0)) + ...                           % Hack: Otherwise, we might have some stationarity issues...
                  log(normpdf(pv.gamma_q,0,10)) + ...
                  log(normpdf(pv.q0,0,1)) + ...
                  log(exppdf(pv.sig2_q_init,1));
    end % end getlogPrior
    
  end % end methods 

  methods (Access = private)
    %% True when model_restr selects a restricted specification (1, 2 or 3)
    function tf = isRestrictedModel(self)
      tf = isscalar(self.model_restr) && any(self.model_restr == [1 2 3]);
    end % end isRestrictedModel

    %% Tie the restricted parameters to their free counterparts
    % sig2_q_init follows sig2_q; mu_q2 and mu_q3 follow mu_q1.
    function self = tieRestrictedParams(self)
      self.params.sig2_q_init.value = self.params.sig2_q.value;
      self.params.mu_q2.value       = self.params.mu_q1.value;
      self.params.mu_q3.value       = self.params.mu_q1.value;
    end % end tieRestrictedParams
  end % end private methods

end % end Inflation