classdef ShortRate < OptimProblem
  % ShortRate  Regime-dependent short-rate model estimated by maximum likelihood.
  %   The observed short-rate series, prefixed by the initial value r0, is
  %   transformed (rates below rbar are mapped through a log transformation,
  %   see getTransformedSeries). The transformed rate y then follows
  %       z(t)      = y(t) - mu(m(t)) - a_r*(y(t-1) - mu(m(t)))
  %       sig2(t+1) = sig2_r + beta_r*(sig2(t) - sig2_r)
  %                 + alpha_r*((z(t) - gamma_r*sqrt(sig2(t)))^2 - (1 + gamma_r^2)*sig2_r)
  %   where m(t) is the monetary-policy regime read from the Monetary object
  %   and beta_r = alphabetagamma_r - alpha_r*(1 + gamma_r^2). The residuals z
  %   are correlated (rho_qr) with an externally supplied series zq.
  %
  %   Construct with name/value pairs that override the properties below, e.g.
  %       sr = ShortRate('series', r, 'Monetary', mp, 'model_restr', 4);
  properties
    name          = 'ShortRate';

    calendar      = [];         % (Optional) calendar associated with the series below
    series        = [];         % Observed short-rate series (r0 is prepended internally)
    transfseries  = [];         % Transformed version of [r0; series], see getTransformedSeries

    Monetary      = [];         % Monetary object; its regimes and params.m0 select the regime mean
    Inflation     = [];         % Placeholder for the Inflation object (not used by the methods below)

    cond_dist     = 'Normal';   % The (conditional) distribution of the innovations (the likelihood below is Gaussian)
    method        = 'MLE';      % The estimation method (only 'MLE' is implemented)

    rbar          =  0.0050;    % Transformation threshold: rates below rbar are log-transformed
    c             =  0;         % Transformation shift: the log is applied to (rate - c)

    model_restr   =  1;         % 1 for Wilkie-like, 2 for ExtendedWilkie-like, 3 for ADG-like, 4 for full ESG.
                                % Options 1-3 impose the same restrictions on this component:
                                % one mean for all regimes and a constant conditional variance.

    %%#####  INTERNAL  ############################################################
    mle  = struct();           % A placeholder for MLE results (filled by fminsearch)
  end % end properties

  methods
    %% Construct a ShortRate model
    % Name/value pairs in varargin override the property defaults. The model
    % parameters are then registered; for model_restr 1-3 the restricted
    % parameters are fixed and tied to their unrestricted counterparts.
    function self = ShortRate(varargin)
      self = self@OptimProblem();

      if mod(numel(varargin), 2) ~= 0
        error('ShortRate:InvalidArguments', ...
              'ShortRate expects name/value pairs; got an odd number of arguments.');
      end
      for k = 1:2:numel(varargin)
        self.(varargin{k}) = varargin{k+1};
      end

      % ShortRate parameters: (name, initial value, [lower, upper]) --
      % argument semantics as defined by OptimProblem.addParameter.
      self.addParameter('mu_r1',            0.4210,       [-5.000,  5.000]);
      self.addParameter('mu_r2',           -0.0193,       [-5.000,  5.000]);
      self.addParameter('mu_r3',           -0.6672,       [-5.000,  5.000]);
      self.addParameter('a_r',              0.990,        [-0.995,  0.995]);
      self.addParameter('rho_qr',          -0.0028,       [-1.000,  1.000]);
      self.addParameter('sig2_r',           2.0006e-05,   [ 1e-9,   0.005]);
      self.addParameter('alpha_r',          0.3407,       [ 0.00,  1.0000]);
      self.addParameter('alphabetagamma_r', 0.9755,       [ 0.00,  0.9950]);
      self.addParameter('gamma_r',          0,            [  -10,      10]);
      self.addParameter('r0',               0.0714,       [ 0.00001, 1.00]);
      self.addParameter('sig2_r_init',      0.0015,       [ 0.00,   0.005]);

      % Models 1-3 share the same restrictions: a single regime mean
      % (mu_r2 = mu_r3 = mu_r1) and a constant conditional variance
      % (alpha_r = alphabetagamma_r = gamma_r = 0, hence beta_r = 0).
      if self.isRestrictedModel()
        restricted = {'mu_r2', 'mu_r3', 'alpha_r', 'alphabetagamma_r', 'gamma_r', 'sig2_r_init'};
        for k = 1:numel(restricted)
          self.params.(restricted{k}).fixed = true;
        end

        self.params.alpha_r.value           = 0;
        self.params.alphabetagamma_r.value  = 0;
        self.params.gamma_r.value           = 0;
        self.syncTiedParams();
      end

      self.getTransformedSeries();
    end % end ShortRate

    %% This function returns the parameters
    % Returns the base-class parameter struct, augmented with the derived
    % beta_r. For model_restr 1-3 it also re-ties the fixed parameters
    % (note: this writes to self.params as a side effect).
    function pv = getPV(self)
      pv = getPV@OptimProblem(self);

      if self.isRestrictedModel()
        self.syncTiedParams();

        pv.sig2_r_init = self.params.sig2_r.value;
        pv.mu_r2       = self.params.mu_r1.value;
        pv.mu_r3       = self.params.mu_r1.value;
      end

      pv.beta_r = pv.alphabetagamma_r - pv.alpha_r*(1 + pv.gamma_r^2);
    end % end getPV

    %% Run the base-class fminsearch and, for MLE, store the results in self.mle
    function [results] = fminsearch(self, lambda, varargin)
      results = fminsearch@OptimProblem(self, lambda, varargin{:});
      % Re-evaluate the objective at the current parameter values to
      % recover the residuals and conditional variances.
      [~, z, sig2] = self.objective(self.getPValues, varargin{:});

      switch self.method
        case 'MLE'
          self.mle.params = self.getPV;
          self.mle.z      = z;
          self.mle.sig2   = sig2;
          self.mle.out    = results;
      end
    end % end fminsearch

    %% This function computes the objective function; the only case considered so far is the MLE
    % For 'MLE' the objective is the negative log-likelihood.
    function [S,z,sig2] = objective(self, x, varargin)
      switch self.method
        case 'MLE'
          [logl, z, sig2] = self.getMLEfunction(x, varargin{:});
          S = -sum(logl);
        otherwise
          error('ShortRate:UnsupportedMethod', ...
                'Estimation method ''%s'' is not supported; only ''MLE'' is implemented.', ...
                self.method);
      end
    end % end objective

    %% This function transforms the interest rates into their transformed versions
    % Builds [r0; series] and maps every value below rbar to
    % a0 + a1*log(value - c), with a1 = rbar - c and a0 chosen so that the
    % mapping is continuous at rbar; values at or above rbar are unchanged.
    % The result is stored in self.transfseries and returned (column vector).
    function transfseries = getTransformedSeries(self)
      pv     = self.getPV;
      series = [pv.r0; self.series(:)];

      % We transform the series of short rates according to the transformation
      % given in the paper (similar to Engle et al., 2017)
      a1 = self.rbar - self.c;
      a0 = self.rbar - a1*log(a1);

      % Only evaluate the log where it is used, so that out-of-domain values
      % in the unused branch cannot inject NaN/complex results.
      transfseries      = series;
      low               = series < self.rbar;
      transfseries(low) = a0 + a1.*log(series(low) - self.c);

      self.transfseries = transfseries;
    end % end getTransformedSeries

    %% This function computes the likelihood function for the short-rate model
    % p  : parameter vector, loaded via setPValues
    % zq : series correlated with the short-rate residuals through rho_qr,
    %      one element per observation in self.series
    %      (NOTE(review): presumably standardized inflation shocks -- confirm)
    % Returns the per-observation log-likelihoods logl, the residuals z and
    % their conditional variances sig2 (each with numel(self.series) rows).
    % logl is all-NaN when the recursion produces a complex result.
    function [logl,z,sig2] = getMLEfunction(self,p,zq)
      % We extract the current value of the parameters
      self.setPValues(p);
      pv = self.getPV;

      % Re-transform with the current parameters: y(1) is built from r0,
      % which is estimated, so a series cached at construction would be stale.
      y = self.getTransformedSeries();
      T = numel(y);

      % Monetary policy regime chain; m(1) is the initial regime m0
      m = [floor(self.Monetary.params.m0.value); self.Monetary.regimes(:)];

      mus  = [pv.mu_r1, pv.mu_r2, pv.mu_r3];
      z    = zeros(T,1);
      sig2 = zeros(T+1,1);

      % sig2(t) is the conditional variance of z(t); the recursion is
      % seeded with sig2_r_init for the first residual (t = 2).
      sig2(2) = pv.sig2_r_init;
      for t = 2:T
        mu_t      = mus(m(t));
        z(t)      = y(t) - mu_t - pv.a_r*(y(t-1) - mu_t);
        sig2(t+1) = pv.sig2_r + pv.beta_r*(sig2(t) - pv.sig2_r) ...
                  + pv.alpha_r*((z(t) - pv.gamma_r*sqrt(sig2(t)))^2 - (1 + pv.gamma_r^2)*pv.sig2_r);
      end
      sig2 = sig2(2:end-1);
      z    = z(2:end);

      % Conditional on zq: z ~ N(rho_qr*sqrt(sig2)*zq, sig2*(1 - rho_qr^2)).
      % The log-density is evaluated analytically rather than via
      % log(normpdf(...)), which underflows to -Inf in the tails.
      v    = sig2 .* (1 - pv.rho_qr^2);
      e    = z - pv.rho_qr .* sqrt(sig2) .* zq(:);
      logl = -0.5 .* (log(2*pi) + log(v) + e.^2 ./ v);

      % A negative variance (possible for unconstrained parameter values)
      % makes the result complex; flag the whole evaluation as invalid.
      if any(imag(logl) ~= 0)
        logl = NaN(size(logl));
      end
    end % end getMLEfunction

    %% This function computes the log prior for the short-rate model
    % Sum of independent log-prior terms; the indicator terms evaluate to
    % log(0) = -Inf outside their admissible ranges.
    function logprior = getlogPrior(self,pv)
      logprior =  log(normpdf(pv.mu_r1,0,1)) + ...
                  log(normpdf(pv.mu_r2,0,1)) + ...
                  log(normpdf(pv.mu_r3,0,1)) + ...
                  log((pv.a_r <= 0.995).*(pv.a_r >= -0.995)) + ...                          % Hack: Otherwise, we might have some stationarity issues...
                  log(betapdf((pv.rho_qr + 1)./2,1.5,1.5)) + ...                            % Marginal of the LKJ distribution for one correlation is beta
                  log(exppdf(pv.sig2_r,1)) + ...
                  log((pv.alpha_r <= 0.995).*(pv.alpha_r >= 0)) + ...                       % Hack: Otherwise, we might have some stationarity issues...
                  log((pv.alphabetagamma_r <= 0.995).*(pv.alphabetagamma_r >= 0)) + ...     % Hack: Otherwise, we might have some stationarity issues...
                  log((pv.beta_r <= 0.995).*(pv.beta_r >= 0)) + ...                         % Hack: Otherwise, we might have some stationarity issues...
                  log(normpdf(pv.gamma_r,0,10)) + ...
                  log(normpdf(pv.r0,0,1).*(pv.r0 > 0)) + ...
                  log(exppdf(pv.sig2_r_init,1));
    end % end getlogPrior

  end % end methods

  methods (Access = private)
    %% True when model_restr selects one of the restricted variants (1-3)
    function tf = isRestrictedModel(self)
      tf = ismember(self.model_restr, [1 2 3]);
    end

    %% Tie the fixed parameters of the restricted variants to their free counterparts
    function syncTiedParams(self)
      self.params.sig2_r_init.value = self.params.sig2_r.value;
      self.params.mu_r2.value       = self.params.mu_r1.value;
      self.params.mu_r3.value       = self.params.mu_r1.value;
    end
  end % end private methods

end % end ShortRate