classdef ForwardRate < OptimProblem
  % FORWARDRATE  Forward-rate model built on the 3-month-to-30-year curve.
  %
  %   Forward rates are computed from continuously compounded zero rates
  %   held in SERIES (fields Rate_3m, Rate_1y, Rate_2y, Rate_3y, Rate_5y,
  %   Rate_7y, Rate_10y, Rate_30y), transformed (see getTransformedSeries)
  %   and expressed as spreads over the transformed 3-month rate. Each of
  %   the seven spreads is modelled as a mean plus loadings on two factors
  %   plus independent normal noise; the two factors follow independent
  %   AR(1) processes around the means mu_F. Estimation goes through the
  %   inherited OptimProblem machinery (see fminsearch below).
  %
  %   obj = ForwardRate('Name', Value, ...) assigns properties by name
  %   before the model parameters are declared, so e.g. 'model_restr'
  %   selects which parameters are fixed.
  properties 
    name          = 'ForwardRate';
    
    calendar      = [];         % (Optional) calendar associated with SERIES
    series        = [];         % Zero-rate data; fields Rate_3m ... Rate_30y are read (presumably a table)
    transfseries  = [];         % Transformed forward rates as spreads over the 3m rate (T x 7)
    maturities    = [];         % Maturities of rates
    
    cond_dist     = 'Normal';   % The (conditional) distribution of the innovations. 
    method        = 'MLE';      % The estimation method; only 'MLE' is implemented
      
    fbar          =  0.0050;    % Forward rate transformation parameters: rates below fbar
    c             =  0;         % are mapped through a log of (rate - c)

    model_restr   =  1;         % 1 for Wilkie-like, 
                                % 2 for Extended Wilkie-like, 
                                % 3 for ADG-like, 
                                % 4 for full new ESG
                                
    %%#####  INTERNAL  ############################################################
    mle           = struct();   % A placeholder for MLE results
  end % end properties
    
  methods
    %% Construct a ForwardRate model 
    function self = ForwardRate(varargin)
      self = self@OptimProblem();

      % Apply name/value pairs first so that settings such as 'model_restr'
      % and 'series' are in place before the parameters are configured.
      if mod(numel(varargin), 2) ~= 0
        error('ForwardRate:InvalidArguments', ...
              'Optional arguments must be supplied as name/value pairs.');
      end
      for no = 1:2:numel(varargin)
        self.(varargin{no}) = varargin{no+1};
      end
      
      % Parameters: name, starting value, [lower, upper] bounds
      self.addParameter('mu_f1',       0.0021,       [-0.200, 0.200]);
      self.addParameter('mu_f2',       0.0049,       [-0.200, 0.200]);
      self.addParameter('mu_f3',       4.0310e-05,   [-0.200, 0.200]);
      self.addParameter('mu_f5',       0.0025,       [-0.200, 0.200]);
      self.addParameter('mu_f7',       0.0045,       [-0.200, 0.200]);
      self.addParameter('mu_f10',      3.4208e-04,   [-0.200, 0.200]);
      self.addParameter('mu_f30',      0.0030,       [-0.200, 0.200]);

      self.addParameter('sig2_f1',     1.8352e-05,   [ 0.00000001, 0.001]);
      self.addParameter('sig2_f2',     2.3891e-05,   [ 0.00000001, 0.001]);
      self.addParameter('sig2_f3',     2.2179e-05,   [ 0.00000001, 0.001]);
      self.addParameter('sig2_f5',     1.9838e-05,   [ 0.00000001, 0.001]);
      self.addParameter('sig2_f7',     2.6611e-05,   [ 0.00000001, 0.001]);
      self.addParameter('sig2_f10',    3.0773e-05,   [ 0.00000001, 0.001]);
      self.addParameter('sig2_f30',    1.7891e-05,   [ 0.00000001, 0.001]);
      
      self.addParameter('A_f11',       0.1679,       [    -2.000, 2.000]);
      self.addParameter('A_f12',       0.3895,       [    -2.000, 2.000]);
      self.addParameter('A_f13',       0.7413,       [    -2.000, 2.000]);
      self.addParameter('A_f15',       0.9089,       [    -2.000, 2.000]);
      self.addParameter('A_f17',       1.0282,       [    -2.000, 2.000]);
      self.addParameter('A_f110',      1.1707,       [    -2.000, 2.000]);
      self.addParameter('A_f130',      1.1331,       [    -2.000, 2.000]);
      
      self.addParameter('A_f21',      -0.2413,       [    -2.000, 2.000]);
      self.addParameter('A_f22',      -0.6885,       [    -2.000, 2.000]);
      self.addParameter('A_f23',      -0.7783,       [    -2.000, 2.000]);
      self.addParameter('A_f25',      -0.6837,       [    -2.000, 2.000]);
      self.addParameter('A_f27',      -0.5506,       [    -2.000, 2.000]);
      self.addParameter('A_f210',     -0.3581,       [    -2.000, 2.000]);
      self.addParameter('A_f230',     -0.2338,       [    -2.000, 2.000]);
      
      self.addParameter('mu_F1',      -0.0727,       [    -2.000, 2.000]);
      self.addParameter('mu_F2',       0.0087,       [    -2.000, 2.000]);

      self.addParameter('A_F1',        0.9950,       [ 0.000, 0.999]);
      self.addParameter('A_F2',        0.9386,       [ 0.000, 0.999]);

      self.addParameter('sig2_F1',     7.0579e-06,   [ 0.000001, 0.001]);
      self.addParameter('sig2_F2',     1.7092e-05,   [ 0.000001, 0.001]);
      
      self.addParameter('F_10',        0.0074,       [-0.200, 0.200]);
      self.addParameter('F_20',        0.0175,       [-0.200, 0.200]);

      % Parameter names grouped by role, in the same order as the seven
      % transformed series (columns of transfseries).
      loadings1 = {'A_f11', 'A_f12', 'A_f13', 'A_f15', 'A_f17', 'A_f110', 'A_f130'};
      loadings2 = {'A_f21', 'A_f22', 'A_f23', 'A_f25', 'A_f27', 'A_f210', 'A_f230'};
      means     = {'mu_f1', 'mu_f2', 'mu_f3', 'mu_f5', 'mu_f7', 'mu_f10', 'mu_f30'};
      variances = {'sig2_f1', 'sig2_f2', 'sig2_f3', 'sig2_f5', 'sig2_f7', 'sig2_f10', 'sig2_f30'};

      % Depending on the model selected, fix some loadings at zero and use
      % model-specific starting values. Any other value (e.g. 4) keeps all
      % parameters free with the defaults declared above.
      switch self.model_restr
        case 1
          % Wilkie-like: neither factor loads on the spreads.
          fixedAtZero = [loadings1, loadings2];
          muStart     = [0.0055, 0.0115, 0.0146, 0.0195, 0.0240, 0.0240, 0.0256];
          sig2Start   = [2.2873e-05, 7.2221e-05, 1.0609e-04, 1.7762e-04, ...
                         2.2312e-04, 2.5857e-04, 2.5425e-04];
        case {2, 3}
          % Only the first factor loads on the spreads.
          % NOTE(review): models 2 and 3 currently share identical
          % restrictions and starting values -- confirm this is intended
          % for the ADG-like specification.
          fixedAtZero = loadings2;
          muStart     = [0.0023, 0.0067, -4.7273e-04, 0.0034, 0.0038, 0.0028, 0.0019];
          sig2Start   = [2.3454e-05, 5.7570e-05, 7.2775e-05, 5.2910e-05, ...
                         4.9161e-05, 3.4692e-05, 2.2920e-05];
        otherwise
          fixedAtZero = {};
          muStart     = [];
          sig2Start   = [];
      end

      for k = 1:numel(fixedAtZero)
        self.params.(fixedAtZero{k}).fixed = true;
        self.params.(fixedAtZero{k}).value = 0;
      end
      for k = 1:numel(muStart)
        self.params.(means{k}).value     = muStart(k);
        self.params.(variances{k}).value = sig2Start(k);
      end

      % The transformed series needs data; skip it when no series was given
      % (dot-indexing the default [] would otherwise error).
      if ~isempty(self.series)
        self.getTransformedSeries();
      end
    end % end ForwardRate
    
    %% This function returns the parameters
    function pv = getPV(self)
      % Thin wrapper around OptimProblem.getPV.
      pv        = getPV@OptimProblem(self);
    end % end getPV
    
    %% This functions uses fminsearch to find the MLE parameters
    function [results] = fminsearch(self, lambda, varargin)
      % Runs the inherited optimizer and, for MLE, stores the estimated
      % parameters, the implied factor path and the optimizer output in
      % self.mle.
      %
      % Optional arguments are forwarded as a comma-separated list
      % (varargin{:}). Passing the varargin cell itself would nest the
      % caller's arguments one level deep inside the parent's varargin.
      results = fminsearch@OptimProblem(self, lambda, varargin{:});
      
      switch self.method
        case 'MLE'
          % Re-evaluate at the current parameters to recover the factor path.
          [~,F] = self.objective(self.getPValues);
          self.mle.params = self.getPV;
          self.mle.F      = F;
          self.mle.out    = results;
      end
    end % end fminsearch

    %% This function computes the objective function; the only case considered so far is the MLE
    function [S,z] = objective(self, x, varargin)
      % Returns S, the negative total log-likelihood at parameter vector x,
      % and z, the implied factor path from getMLEfunction.
      switch self.method
        case 'MLE'
          [logl,z] = self.getMLEfunction(x);
          S = -sum(logl);
        otherwise
          error('ForwardRate:UnsupportedMethod', ...
                'Estimation method ''%s'' is not supported; only ''MLE'' is implemented.', ...
                self.method);
      end
    end % end objective

    %% This function transforms the interest rates into their transformed versions
    function transfseries = getTransformedSeries(self)
      % Computes the eight forward rates between consecutive maturities,
      % applies the log transform below fbar, and returns the seven
      % transformed rates as spreads over the transformed 3-month rate
      % (T x 7). The result is also stored in self.transfseries.
      s = self.series;

      % Forward rate between maturities t1 < t2 from continuously compounded
      % zero rates: (R(t2)*t2 - R(t1)*t1) / (t2 - t1). The first column is
      % the 3-month rate itself.
      ff = [s.Rate_3m, ...
            (s.Rate_1y*1   - s.Rate_3m/4)./0.75, ...
            (s.Rate_2y*2   - s.Rate_1y*1)./1, ...
            (s.Rate_3y*3   - s.Rate_2y*2)./1, ...
            (s.Rate_5y*5   - s.Rate_3y*3)./2, ...
            (s.Rate_7y*7   - s.Rate_5y*5)./2, ...
            (s.Rate_10y*10 - s.Rate_7y*7)./3, ...
            (s.Rate_30y*30 - s.Rate_10y*10)./20];
      
      % For f < fbar use a0 + a1*log(f - c). a0 and a1 make the map equal
      % fbar with unit slope at f = fbar. Rates at or above fbar are kept.
      % Only the rates below fbar go through the log. Multiplying both
      % branches by 0/1 masks would let NaN/Inf from the unused branch leak
      % into the result (Inf*0 = NaN).
      a0  = self.fbar - (self.fbar - self.c)*log(self.fbar - self.c);
      a1  = self.fbar - self.c;
      g   = ff;
      low = ff < self.fbar;
      g(low) = a0 + a1.*log(ff(low) - self.c);

      % Spreads of the seven forward rates over the 3-month rate
      transfseries = g(:,2:end) - g(:,1);
      
      % We store the transformed series in self.transfseries
      self.transfseries = transfseries;
    end % end getTransformedSeries
    
    %% This function gives matrix A_f
    function phis = getAf(self)
      % Returns the 7 x 2 matrix of factor loadings. Column j holds the
      % loadings on factor j, and the rows follow the transformed series.
      pv = self.getPV;
      
      % The Wilkie model does not allow for factors
      if self.model_restr == 1
        phis = zeros(7,2);
      else
        phis = [pv.A_f11,  pv.A_f21; ...
                pv.A_f12,  pv.A_f22; ...
                pv.A_f13,  pv.A_f23; ...
                pv.A_f15,  pv.A_f25; ...
                pv.A_f17,  pv.A_f27; ...
                pv.A_f110, pv.A_f210; ...
                pv.A_f130, pv.A_f230];
      end
    end % end getAf

    %% This function computes the likelihood function for the forward rate model
    function [logl,F] = getMLEfunction(self,p)
      % Sets the parameter vector p and returns:
      %   logl - 1 x T per-observation log-likelihood
      %   F    - (T+1) x 2 factor path; the first row holds the initial
      %          values (F_10, F_20)
      %
      % Observation t contributes
      %   log N(e_t; 0, diag(sig2_f)) + log N(u_t; 0, diag(sig2_F))
      % where e_t = y_t - mu_f - A_f*F_t and
      %       u_t = F_t - mu_F - A_F*(F_{t-1} - mu_F).
      % Both covariances are diagonal, so the log-densities are evaluated in
      % closed form. This avoids log(mvnpdf(...)) underflowing to -Inf for
      % observations far in the tails.
      self.setPValues(p);
      pv = self.getPV;
      
      y = self.transfseries;
      T = size(y, 1);

      % The two factors are built from the transformed spreads and are
      % preceded by their initial values.
      F = [pv.F_10,           pv.F_20; ...
           y(:,7) - y(:,1),   y(:,7) + y(:,1) - 2.*y(:,3)];
      
      % We organize the various vectors and matrices needed (rows)
      muf   = [pv.mu_f1, pv.mu_f2, pv.mu_f3, pv.mu_f5, pv.mu_f7, pv.mu_f10, pv.mu_f30];
      sig2f = [pv.sig2_f1, pv.sig2_f2, pv.sig2_f3, pv.sig2_f5, pv.sig2_f7, pv.sig2_f10, pv.sig2_f30];
      muF   = [pv.mu_F1, pv.mu_F2];
      aF    = [pv.A_F1, pv.A_F2];          % diagonal of A_F
      sig2F = [pv.sig2_F1, pv.sig2_F2];
      Af    = self.getAf();                % 7 x 2 loadings

      % An invalid covariance (non-positive or non-finite variance) rejects
      % the parameters with -Inf, as the former try/catch around mvnpdf did.
      allVar = [sig2f, sig2F];
      if ~all(isfinite(allVar) & allVar > 0)
        logl = -Inf(1, T);
        return;
      end

      Fnow  = F(2:end, :);
      Fprev = F(1:end-1, :);
      E = y - muf - Fnow*Af.';                 % T x 7 observation residuals
      U = Fnow - muF - (Fprev - muF).*aF;      % T x 2 factor innovations

      logl = (gaussLogPdf(E, sig2f) + gaussLogPdf(U, sig2F)).';
    end % end getMLEfunction
    
    %% This function computes the log prior for the forward rate model
    function logprior = getlogPrior(self,pv) %#ok<INUSL>
      % Returns the log prior density at the parameter struct pv:
      %   - standard normal on mu_f*, A_f*, mu_F* and the initial factors F_*0;
      %   - Exp(1) on all variances. In addition, each sig2_f* must be
      %     >= 1e-7. This hack keeps the covariance matrix valid.
      %   - an indicator 0 <= A_F* <= 0.999. This hack guards against
      %     stationarity issues.
      % A violated constraint contributes log(0) = -Inf.
      logN   = @(x) log(normpdf(x,0,1));
      logExp = @(x) log(exppdf(x,1));
      logInd = @(tf) log(double(tf));
      minVar = 1e-7;

      logprior = 0;
      for name = {'mu_f1', 'mu_f2', 'mu_f3', 'mu_f5', 'mu_f7', 'mu_f10', 'mu_f30', ...
                  'A_f11', 'A_f12', 'A_f13', 'A_f15', 'A_f17', 'A_f110', 'A_f130', ...
                  'A_f21', 'A_f22', 'A_f23', 'A_f25', 'A_f27', 'A_f210', 'A_f230'}
        logprior = logprior + logN(pv.(name{1}));
      end
      for name = {'sig2_f1', 'sig2_f2', 'sig2_f3', 'sig2_f5', 'sig2_f7', 'sig2_f10', 'sig2_f30'}
        v = pv.(name{1});
        logprior = logprior + logExp(v) + logInd(v >= minVar);
      end
      for name = {'A_F1', 'A_F2'}
        a = pv.(name{1});
        logprior = logprior + logInd((a <= 0.999) & (a >= 0));
      end
      for name = {'mu_F1', 'mu_F2'}
        logprior = logprior + logN(pv.(name{1}));
      end
      for name = {'sig2_F1', 'sig2_F2'}
        logprior = logprior + logExp(pv.(name{1}));
      end
      for name = {'F_10', 'F_20'}
        logprior = logprior + logN(pv.(name{1}));
      end
    end % end getlogPrior
    
  end % end methods 

end % end ForwardRate

function lp = gaussLogPdf(X, s2)
  % Row-wise log density of X under independent zero-mean normals. s2 is a
  % row vector with one variance per column of X. Returns a column vector
  % with one entry per row of X.
  lp = -0.5 * (numel(s2)*log(2*pi) + sum(log(s2)) + sum(X.^2 ./ s2, 2));
end