function [ data_, hRecover, Hfull ] = uwacJCEDD( Zm, fc, tau_p, b_p, data, T, M,...
    measuredCarriers, pilotCarriers, nullCarriers, pilotData, nullData,...
    mQAM, N0, structSparseAlgo, structJCEDD, iCount)
%Joint Channel Estimation & Data Detection (JCEDD)
% This function performs a CS based channel recovery and data detection for
% underwater acoustic communications. The following steps are executed:
%
% 1. Construct the Dictionary Matrix: Using the tentative estimate of data
% input to this function, and the grids of delays and doppler scale
% factors, construct the dictionary at the output of each partial interval
% demodulator (uwacMeasurementMatrix).
%
% 2. Sparse Channel Vector Recovery: Recover the channel vector (OMP) either
% from the full interval FFT-demodulator (FID) output OR from the stacked
% partial interval FFT-demodulator (PID) outputs. The choice is made by
% structJCEDD.SparseRecovery.Model ('FID', case-insensitive, selects the
% full interval model; any other value selects the PID model).
%
% 3. Estimate Channel Matrix: Estimate the channel matrix as seen at each
% partial interval (FFT-) demodulator output (uwacChannelMatrix). Then,
% depending on structJCEDD.DataDetection.Model, form the post-combined
% channel matrix: for 'FID' the partial matrices are simply summed; for any
% other value the partial outputs are linearly combined with the weights
% returned by uwacOptimPartialFFTCombinerWts.
%
% 4. MMSE estimate of data
%
% 5. Demapping to constellation
%
% INPUTS:
%   Zm        : K x M matrix of partial interval (FFT-) demodulator
%             outputs, where K is the total no. of subcarriers
%             (data + pilot + null). The full interval demodulator output
%             is formed inside this function as sum(Zm,2), so a K x 1 input
%             is only meaningful with M = 1.
%
%   fc        : centre frequency; the K subcarrier frequencies are built as
%             fc + (-K/2:K/2-1)/T
%
%   tau_p     : Npaths x 1 vector (grid) of path delays
%
%   b_p       : Ndopls x 1 vector (grid) of scale factors due to doppler
%
%   data      : K x 1 vector of (tentative) data symbols; its entries at
%             pilotCarriers are used to cancel the pilot contribution
%
%   T         : Duration of an OFDM symbol.
%
%   M         : Number of partial FFTs formed
%
%   measuredCarriers: indices of the subcarriers whose observations are
%             used for sparse channel recovery
%
%   pilotCarriers, nullCarriers: indices of pilot and null subcarriers; all
%             remaining indices in 1:K are treated as data subcarriers
%
%   pilotData, nullData: values written into data_ at pilotCarriers and
%             nullCarriers respectively
%
%   mQAM      : constellation size passed to qamdemod/qammod
%
%   N0        : scalar used as the diagonal loading N0*I in the MMSE solve
%             and forwarded to uwacOptimPartialFFTCombinerWts
%             (presumably the noise variance -- confirm with caller)
%
%   structSparseAlgo: SPARSEALGO structure; only the field MaxPaths is read
%             here (max. number of paths handed to OMPwithInitSupport)
%
%   structJCEDD: structure with fields SparseRecovery.Model and
%             DataDetection.Model (see steps 2 and 3 above)
%
%   iCount    : iteration count (for passing to uwacMeasurementMatrix)
%
% OUTPUTS:
%  data_      : Kx1 vector of refined estimate of data vector
%
%  hRecover   : sparse channel vector returned by OMPwithInitSupport, one
%             entry per (delay, doppler) grid point
%
%  Hfull      : K x K channel matrix seen at the full interval demodulator
%             output, i.e. the sum over the M partial-interval matrices
%
%REFERENCE:
% [1] Arunkumar K.P. and Chandra R. Murthy, "Iterative Sparse Channel
% Estimation and Data Detection for Underwater Acoustic Communications
% Using Partial Interval Demodulation",  IEEE Transactions on Signal
% Processing ( Volume: 66 , Issue: 19 , Oct.1, 1 2018 )
%
%Author  : Arunkumar K. P.
%Address : Ph.D. Scholar,
%          Signal Processing for Communications Lab, ECE Department,
%          Indian Institute of Science, Bangalore, India-560 012.
%Email   : arunkumar@iisc.ac.in
%
%REVISION HISTORY
% Version : 2.1
% Last Revision: 05-02-2018
%
% This script/program is released under the Commons Creative Licence
% with Attribution Non-commercial Share Alike (by-nc-sa)
% http://creativecommons.org/licenses/by-nc-sa/3.0/
%
% Short Disclaimer: this script is for educational purpose only.
%++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

%% 1. Construct Dictionary Matrix (using the current data symbol estimate)
K = size(Zm, 1); % K = number of subcarriers

df = 1/T; %subcarrier spacing

f = fc + (-K/2:K/2-1)*df; %subcarriers

[ Am ] = uwacMeasurementMatrix( f, tau_p, b_p, data, T, M, iCount ); % PID dictionary

%% 2. Sparse Channel Recovery
if strcmpi(structJCEDD.SparseRecovery.Model,'FID') %Sparse Channel Recovery from FID output
    % (2.1) Form full-FFT demodulator output and the matching dictionary
    % by summing over the partial intervals
    Zfid = sum(Zm,2);
    Afid = sum(Am,3);

    % (2.2) Sparse Channel Recovery from FID output
    [hRecover, ~] =...
        OMPwithInitSupport(Afid(measuredCarriers,:), Zfid(measuredCarriers,:), structSparseAlgo.MaxPaths, []);

else % Sparse Channel Recovery from PID output
    % (2.1) Make CS Dictionary corresponding to PID model: stack the M
    % partial FFT dictionaries on top of each other (rows ordered as
    % [interval 1; interval 2; ...; interval M]) and stack the
    % observations in the same order. Done in one shot instead of growing
    % the arrays inside a loop.
    Asub = Am(measuredCarriers,:,1:M);            % P x Ngrids x M
    Am_measured = reshape( permute(Asub,[1 3 2]),...
        size(Asub,1)*M, size(Asub,2) );           % (P*M) x Ngrids
    Zm_measured = Zm(measuredCarriers,1:M);       % P x M, (:) stacks by interval

    % (2.2) CS Recovery of sparse channel vector...
    [hRecover, ~] = ...
        OMPwithInitSupport(Am_measured, Zm_measured(:), structSparseAlgo.MaxPaths, []);

end

%% 3. Data Detection
% Data subcarriers are whatever is neither pilot nor null (setdiff already
% returns its result sorted).
dataCarriers = setdiff( 1:K, union( pilotCarriers, nullCarriers) );

% (3.0) Construct Channel matrix for the partial interval demodulator (PID)
% from the recovered channel vector...
% Expand the delay and doppler grids to one entry per dictionary atom
% (delay index runs fastest), then keep only the atoms on the recovered
% support, ordered by increasing delay.
support = abs(hRecover) > 0;

tauGrid = repmat(tau_p(:), 1, length(b_p));
tauGrid = tauGrid(:); %Now NtxNb = Ngrids
[tau_p0__, tauInd] = sort(tauGrid(support),'ascend');

dopGrid = repmat(b_p(:)', length(tau_p), 1);
dopGrid = dopGrid(:); %Now NtxNb = Ngrids
b_p0__ = dopGrid(support);
b_p0__ = b_p0__(tauInd);

A_p0__ = hRecover(support);
A_p0__ = A_p0__(tauInd);
A_p0__ = repmat(A_p0__(:), 1, M); % same amplitude replicated for each partial interval

[ HmEstim ] = uwacChannelMatrix(fc, K, T, M, A_p0__, tau_p0__, b_p0__);
    %HmEstim is the channel matrix seen at partial interval demodulator output

Hfull = sum(HmEstim,3); %Hfull is one of the o/p arg (HmEstim is bulkier by a factor of M!)
    %Hfull is the channel matrix seen at full interval demodulator output

% (3.1) Form the post-combined observation Z and channel matrix Hestim
if strcmpi(structJCEDD.DataDetection.Model,'FID')
    % Full interval demodulator: plain sum over the partial intervals
    Z = sum(Zm,2);
    Hestim = Hfull;

else
    % (3.1a) Form optimum combiner weights for PID outputs from the
    % strongest recovered atoms. The count is clamped so that a grid with
    % fewer than MaxPaths atoms does not index out of range.
    [~, hInd] = sort(abs(hRecover), 'descend');
    nKeep = min(structSparseAlgo.MaxPaths, numel(hRecover));
    strongest = hInd(1:nKeep);

    tau_p0_ = tauGrid(strongest);
    a_p0_ = dopGrid(strongest);
    ampStrongest = hRecover(strongest);
    A_p0_ = repmat(ampStrongest(:), 1, M);
    [ Wopt_ ] = uwacOptimPartialFFTCombinerWts( f, tau_p0_, a_p0_, A_p0_, T, M, N0  );

    % (3.1b) Replicate the weights along a new third dimension and swap
    % dimensions 2 and 3 so they line up with HmEstim (K x K x M)
    Wopt_H_ = permute( repmat(Wopt_, 1, 1, K), [1 3 2] );

    % (3.1c) Post-combined channel matrix: weighted sum of PID channel matrices
    Hestim = sum(HmEstim.*conj(Wopt_H_),3);

    % (3.1d) Post-combined measurements
    Z = sum(Zm.*conj(Wopt_),2);
end

% (3.2) MMSE Estimation of data symbols: cancel the pilot contribution,
% then solve the regularised normal equations on the data subcarriers
data_ = zeros(K,1);
Z_ = Z - Hestim(:, pilotCarriers)*data(pilotCarriers);
Hestim_ = Hestim(:,dataCarriers);
data_(dataCarriers) = (Hestim_'*Hestim_+N0*eye(length(dataCarriers)))\(Hestim_'*Z_); %MMSE estimate of the data vector

% (3.3) Demapping to M-QAM constellation
data_(dataCarriers) = qammod( qamdemod(data_(dataCarriers),mQAM), mQAM );
data_(pilotCarriers) = pilotData;
data_(nullCarriers) = nullData;  %zero-out null data

end
