#include "NeuralAmpModeler.h"

#ifdef SHR3D_SFX_CORE_NEURALAMPMODELER

#include "architecture.hpp"

#include <algorithm>
#include <cassert>
#include <iostream>

// Cutoff frequency of the DC-blocking high-pass filter applied near the end of ProcessBlock()
// (passed to recursive_linear_filter::HighPassParams together with dSampleRate; presumably Hz).
const double kDCBlockerFrequency = 5.0;

// Constructor: selects the fast tanh activation and wires the noise-gate gain stage to its trigger.
NeuralAmpModeler::NeuralAmpModeler()
{
  // Static (class-level) switch on the NAM activation type, not a per-instance setting.
  activations::Activation::enable_fast_tanh();

  // Register the gain stage as a listener of the trigger (presumably so the trigger's computed
  // gain reduction reaches it during processing — see dsp::noise_gate).
  this->mNoiseGateTrigger.AddListener(&this->mNoiseGateGain);
}

// Destructor: frees the per-channel pointer arrays owned by this object.
NeuralAmpModeler::~NeuralAmpModeler()
{
  this->_DeallocateIOPointers();
}

// Processes one block of audio through the whole chain:
//   input gain + stereo->mono mix -> noise-gate trigger -> NAM model (or passthrough)
//   -> noise-gate gain -> 3-band tone stack -> IR -> DC-blocking high-pass -> output gain + mono->stereo.
//
// inputs / outputs : two external channels each, nFrames samples per channel.
// sampleRate       : host sample rate; stored in dSampleRate and used for every filter design below.
// input / output   : gain knob values; value / 10 is the gain in dB (see _ProcessInput / _ProcessOutput).
// gate             : noise-gate threshold knob; the threshold used is gate / 10.
// bass/middle/treble: tone knob values; value / 10 is the 0-10 knob position, 5 meaning 0 dB.
// *Enabled flags   : bypass the corresponding stage when false.
// A block with nFrames <= 0 is ignored (nothing to process; a negative count would otherwise wrap to a
// huge size_t when converted below).
void NeuralAmpModeler::ProcessBlock(const float* const* inputs, float** outputs, int nFrames, int sampleRate, int input, int gate, int bass, int middle, int treble, int output, bool gateEnabled, bool eqEnabled, bool normalizeEnabled, bool irEnabled, bool namEnabled)
{
  dSampleRate = double(sampleRate);
  // Bail out before touching buffers or the floating-point environment.
  if (nFrames <= 0)
    return;

  const size_t numChannelsExternalIn = 2;
  const size_t numChannelsExternalOut = 2;
  const size_t numChannelsInternal = kNumChannelsInternal;
  const size_t numFrames = (size_t)nFrames;

  // Disable floating point denormals
  std::fenv_t fe_state;
  std::feholdexcept(&fe_state);
  disable_denormals();

  _PrepareBuffers(numChannelsInternal, numFrames);
  // Input is collapsed to mono in preparation for the NAM.
  _ProcessInput(inputs, numFrames, numChannelsExternalIn, numChannelsInternal, input);
  _ApplyDSPStaging();
  const bool noiseGateActive = gateEnabled /*GetParam(kNoiseGateActive)->Value()*/;
  const bool toneStackActive = eqEnabled;

  // Noise gate trigger
  float** triggerOutput = mInputPointers;
  if (noiseGateActive)
  {
    const double time = 0.01;

    const double threshold = double(gate) / 10.0; // GetParam...
    const double ratio = 0.1; // Quadratic...
    const double openTime = 0.005;
    const double holdTime = 0.01;
    const double closeTime = 0.05;
    const dsp::noise_gate::TriggerParams triggerParams(time, threshold, ratio, openTime, holdTime, closeTime);
    mNoiseGateTrigger.SetParams(triggerParams);
    mNoiseGateTrigger.SetSampleRate(dSampleRate);
    triggerOutput = mNoiseGateTrigger.Process(mInputPointers, numChannelsInternal, numFrames);
  }

  if (namEnabled && mModel != nullptr)
  {
    // TODO multi-channel processing; Issue
    // Make sure it's multi-threaded or else this won't perform well!
    mModel->process(triggerOutput[0], mOutputPointers[0], nFrames);
    mModel->finalize_(nFrames);
    // Normalize loudness
    if (normalizeEnabled)
    {
      _NormalizeModelOutput(mOutputPointers, numChannelsInternal, numFrames);
    }
  }
  else
  {
    _FallbackDSP(triggerOutput, mOutputPointers, numChannelsInternal, numFrames);
  }
  // Apply the noise gate
  float** gateGainOutput =
    noiseGateActive ? mNoiseGateGain.Process(mOutputPointers, numChannelsInternal, numFrames) : mOutputPointers;

  float** toneStackOutPointers = gateGainOutput;
  if (toneStackActive)
  {
    // Translate params from knob 0-10 to dB.
    // Tuned ranges based on my ear. E.g. seems treble doesn't need nearly as
    // much swing as bass can use.

    const double bassGainDB = 4.0 * ((double(bass) / 10.0) - 5.0); // +/- 20
    const double midGainDB = 3.0 * ((double(middle) / 10.0) - 5.0); // +/- 15
    const double trebleGainDB = 2.0 * ((double(treble) / 10.0) - 5.0); // +/- 10

    const double bassFrequency = 150.0;
    const double midFrequency = 425.0;
    const double trebleFrequency = 1800.0;
    const double bassQuality = 0.707;
    // Wider EQ on mid bump up to sound less honky.
    const double midQuality = midGainDB < 0.0 ? 1.5 : 0.7;
    const double trebleQuality = 0.707;

    // Define filter parameters
    recursive_linear_filter::BiquadParams bassParams(dSampleRate, bassFrequency, bassQuality, bassGainDB);
    recursive_linear_filter::BiquadParams midParams(dSampleRate, midFrequency, midQuality, midGainDB);
    recursive_linear_filter::BiquadParams trebleParams(dSampleRate, trebleFrequency, trebleQuality, trebleGainDB);
    // Apply tone stack
    // Set parameters
    mToneBass.SetParams(bassParams);
    mToneMid.SetParams(midParams);
    mToneTreble.SetParams(trebleParams);
    float** bassPointers = mToneBass.Process(gateGainOutput, numChannelsInternal, numFrames);
    float** midPointers = mToneMid.Process(bassPointers, numChannelsInternal, numFrames);
    float** treblePointers = mToneTreble.Process(midPointers, numChannelsInternal, numFrames);
    toneStackOutPointers = treblePointers;
  }

  float** irPointers = toneStackOutPointers;
  if (mIR != nullptr && irEnabled)
    irPointers = mIR->Process(toneStackOutPointers, numChannelsInternal, numFrames);

  // And the HPF for DC offset (Issue 271)
  const double highPassCutoffFreq = kDCBlockerFrequency;
  // const double lowPassCutoffFreq = 20000.0;
  const recursive_linear_filter::HighPassParams highPassParams(dSampleRate, highPassCutoffFreq);
  // const recursive_linear_filter::LowPassParams lowPassParams(dSampleRate, lowPassCutoffFreq);
  mHighPass.SetParams(highPassParams);
  // mLowPass.SetParams(lowPassParams);
  float** hpfPointers = mHighPass.Process(irPointers, numChannelsInternal, numFrames);
  // sample** lpfPointers = mLowPass.Process(hpfPointers, numChannelsInternal, numFrames);

  // restore previous floating point state
  std::feupdateenv(&fe_state);

  // Let's get outta here
  // This is where we exit mono for whatever the output requires.
  _ProcessOutput(hpfPointers, outputs, numFrames, numChannelsInternal, numChannelsExternalOut, output);
}


// Private methods ============================================================

void NeuralAmpModeler::_AllocateIOPointers(const size_t nChans)
{
  if (mInputPointers != nullptr)
  {
    assert(false); // Tried to re-allocate mInputPointers without freeing
  }
  mInputPointers = new float* [nChans];
  if (mInputPointers == nullptr)
  {
    assert(false); // Failed to allocate pointer to input buffer!
  }
  if (mOutputPointers != nullptr)
  {
    assert(false); // Tried to re-allocate mOutputPointers without freeing
  }
  mOutputPointers = new float* [nChans];
  if (mOutputPointers == nullptr)
  {
    assert(false); // Failed to allocate pointer to output buffer!
  }
}

// Applies pending DSP changes; called from ProcessBlock() before any processing.
// Removal requests are handled first, then staged model / IR are promoted to live, so a module staged
// in the same cycle as a removal request ends up live.
void NeuralAmpModeler::_ApplyDSPStaging()
{
  // Step 1: honour removal requests.
  if (mShouldRemoveModel)
  {
    mShouldRemoveModel = false;
    mNAMPath.clear();
    mModel = nullptr;
    mCheckSampleRateWarning = true;
  }
  if (mShouldRemoveIR)
  {
    mShouldRemoveIR = false;
    mIRPath.clear();
    mIR = nullptr;
  }

  // Step 2: promote staged modules to live, leaving the staging slots empty.
  if (mStagedModel)
  {
    mModel = std::move(mStagedModel);
    mStagedModel = nullptr;
    mCheckSampleRateWarning = true;
    mNewModelLoadedInDSP = true;
  }
  if (mStagedIR)
  {
    mIR = std::move(mStagedIR);
    mStagedIR = nullptr;
  }
}

// Sample-rate mismatch warning hook. It currently does nothing; no warning is emitted by this code.
void NeuralAmpModeler::_CheckSampleRateWarning()
{
  // Intentionally empty.
}

// Frees the per-channel pointer arrays and resets both members to nullptr so they can be re-allocated.
// Safe to call repeatedly (delete[] on nullptr is a no-op). Only the pointer arrays are freed; the
// sample storage they point into is owned by mInputArray / mOutputArray.
void NeuralAmpModeler::_DeallocateIOPointers()
{
  delete[] mInputPointers;
  mInputPointers = nullptr;
  delete[] mOutputPointers;
  mOutputPointers = nullptr;
}

// Passthrough used when NAM is disabled or no model is loaded: copies numFrames samples of each of the
// numChannels channels from inputs to outputs (it honours its arguments rather than assuming which
// internal buffers the caller passes).
void NeuralAmpModeler::_FallbackDSP(float** inputs, float** outputs, const size_t numChannels,
  const size_t numFrames)
{
  for (size_t c = 0; c < numChannels; c++)
  {
    // std::copy_n must not copy a range onto itself; identical buffers already hold the result.
    if (inputs[c] != outputs[c])
      std::copy_n(inputs[c], numFrames, outputs[c]);
  }
}

// Scales buffer in place so the model's stored loudness is brought to a target of -18.0
// (gain = 10^((target - loudness) / 20), i.e. the values are treated as dB).
// Does nothing when no model is loaded or the model reports no loudness.
void NeuralAmpModeler::_NormalizeModelOutput(float** buffer, const size_t numChannels, const size_t numFrames)
{
  if (!mModel || !mModel->HasLoudness())
    return;

  const double modelLoudness = mModel->GetLoudness();
  const double targetLoudness = -18.0;
  const double scale = pow(10.0, (targetLoudness - modelLoudness) / 20.0);

  for (size_t ch = 0; ch != numChannels; ++ch)
  {
    float* const samples = buffer[ch];
    for (float* it = samples; it != samples + numFrames; ++it)
      *it *= scale;
  }
}

void NeuralAmpModeler::_ResampleModelAndIR()
{
  // Model
  // TODO

  // IR
  if (mStagedIR != nullptr)
  {
    const double irSampleRate = mStagedIR->GetSampleRate();
    if (irSampleRate != dSampleRate)
    {
      const auto irData = mStagedIR->GetData();
      mStagedIR = std::make_unique<dsp::ImpulseResponse>(irData, dSampleRate);
    }
  }
  else if (mIR != nullptr)
  {
    const double irSampleRate = mIR->GetSampleRate();
    if (irSampleRate != dSampleRate)
    {
      const auto irData = mIR->GetData();
      mStagedIR = std::make_unique<dsp::ImpulseResponse>(irData, dSampleRate);
    }
  }
}

// Loads the model at modelPath via get_dsp() and stages it for _ApplyDSPStaging(); records the path.
// The path is only written after get_dsp() has returned.
void NeuralAmpModeler::_StageModel(const std::string& modelPath)
{
  auto loadedModel = get_dsp(modelPath);
  mStagedModel = std::move(loadedModel);
  mNAMPath = modelPath;
  //SendControlMsgFromDelegate(kCtrlTagModelFileBrowser, kMsgTagLoadedModel, mNAMPath.GetLength(), mNAMPath.Get());
}

// Loads the WAV impulse response at irPath (built for dSampleRate) and stages it for _ApplyDSPStaging().
// Returns the loader's status.
// - On success: the IR is staged and mIRPath is updated.
// - On failure: any staged IR is cleared and mIRPath is left unchanged.
// The IR is built in a local first, so a failed load is never visible through mStagedIR, not even briefly.
dsp::wav::LoadReturnCode NeuralAmpModeler::_StageIR(const std::string& irPath)
{
  // FIXME it'd be better for the path to be "staged" as well. Just in case the
  // path and the model got caught on opposite sides of the fence...
  const auto irFsPath = std::filesystem::path(irPath);
  auto ir = std::make_unique<dsp::ImpulseResponse>(irFsPath.string().c_str(), dSampleRate);
  const dsp::wav::LoadReturnCode wavState = ir->GetWavState();

  if (wavState == dsp::wav::LoadReturnCode::SUCCESS)
  {
    mStagedIR = std::move(ir);
    mIRPath = irPath;
    // SendControlMsgFromDelegate(kCtrlTagIRFileBrowser, kMsgTagLoadedIR, mIRPath.GetLength(), mIRPath.Get());
  }
  else
  {
    mStagedIR = nullptr;
    //mIRPath = previousIRPath;
    // SendControlMsgFromDelegate(kCtrlTagIRFileBrowser, kMsgTagLoadFailed);
  }

  return wavState;
}

// Number of internal channels currently allocated. Input and output buffers are assumed to have equal
// channel counts (no mono->stereo effects), so the input side is reported.
size_t NeuralAmpModeler::_GetBufferNumChannels() const
{
  const size_t channelCount = mInputArray.size();
  return channelCount;
}

// Frames per channel currently allocated (length of the first input channel); 0 when there are no channels.
size_t NeuralAmpModeler::_GetBufferNumFrames() const
{
  return mInputArray.empty() ? 0 : mInputArray.front().size();
}

// Makes mInputArray / mOutputArray hold numChannels channels of numFrames samples each. Whenever the
// size changes, all samples are zeroed. Afterwards mInputPointers / mOutputPointers point at each
// channel's storage. The pointer arrays themselves are re-allocated only when the channel count changes.
void NeuralAmpModeler::_PrepareBuffers(const size_t numChannels, const size_t numFrames)
{
  const bool updateChannels = numChannels != _GetBufferNumChannels();
  const bool updateFrames = updateChannels || (_GetBufferNumFrames() != numFrames);
  //  if (!updateChannels && !updateFrames)  // Could we do this?
  //    return;

  if (updateChannels)
  {
    _PrepareIOPointers(numChannels);
    mInputArray.resize(numChannels);
    mOutputArray.resize(numChannels);
  }
  if (updateFrames)
  {
    // assign() resizes and zero-fills in one step.
    for (auto& channel : mInputArray)
      channel.assign(numFrames, 0.0f);
    for (auto& channel : mOutputArray)
      channel.assign(numFrames, 0.0f);
  }
  // Refresh unconditionally: the resizes above may have moved each channel's storage.
  for (size_t c = 0; c < mInputArray.size(); c++)
    mInputPointers[c] = mInputArray[c].data();
  for (size_t c = 0; c < mOutputArray.size(); c++)
    mOutputPointers[c] = mOutputArray[c].data();
}

void NeuralAmpModeler::_PrepareIOPointers(const size_t numChannels)
{
  _DeallocateIOPointers();
  _AllocateIOPointers(numChannels);
}

// Applies the input gain and mixes the nChansIn external input channels into the single internal
// channel mInputArray[0]. `input` is the gain knob value; input / 10 is the gain in dB.
// Channels are summed, not averaged. With nChansIn == 0 the internal channel is zeroed instead of
// keeping samples from a previous block.
// Assumes _PrepareBuffers() has already sized mInputArray[0] to at least nFrames samples.
void NeuralAmpModeler::_ProcessInput(const float* const* inputs, const size_t nFrames, const size_t nChansIn, const size_t nChansOut, const int input)
{
  // We'll assume that the main processing is mono for now. We'll handle dual amps later.
  if (nChansOut != 1)
  {
    assert(false); // "Expected mono output, but " << nChansOut << " output channels are requested!"
  }

  // On the standalone, we can probably assume that the user has plugged into only one input and they expect it to be
  // carried straight through, so no division over nChansIn is applied: we're just "catching anything out there."
  // Note that a DAW feeding true stereo would instead call for averaging, to avoid doubling the loudness; this code
  // does not do that.
  const double inputDouble = double(input) / 10.0;
  const double gain = pow(10.0, inputDouble / 20.0);

  float* const mono = mInputArray[0].data();
  if (nChansIn == 0)
  {
    std::fill_n(mono, nFrames, 0.0f);
    return;
  }

  // The first channel initialises the mix; the remaining channels accumulate into it.
  for (size_t s = 0; s < nFrames; s++)
    mono[s] = gain * inputs[0][s];
  for (size_t c = 1; c < nChansIn; c++)
    for (size_t s = 0; s < nFrames; s++)
      mono[s] += gain * inputs[c][s];
}

// Applies the output gain and broadcasts the single internal channel (inputs[0]) to each of the
// nChansOut external output channels. `output` is the gain knob value; output / 10 is the gain in dB.
// Assumes _PrepareBuffers() was already called and that processing is mono (nChansIn == 1, asserted).
void NeuralAmpModeler::_ProcessOutput(float** inputs, float** outputs, const size_t nFrames,
  const size_t nChansIn, const size_t nChansOut, const int output)
{
  const double outputDouble = double(output) / 10.0;
  const double gain = pow(10.0, outputDouble / 20.0);

  if (nChansIn != 1)
  {
    assert(false); // Plugin is supposed to process in mono.
  }

  // Broadcast the internal mono stream to all output channels.
  const float* const mono = inputs[0];
  for (size_t channel = 0; channel < nChansOut; channel++)
  {
    float* const dest = outputs[channel];
    for (size_t s = 0; s < nFrames; s++)
      dest[s] = gain * mono[s];
  }
}

#endif // SHR3D_SFX_CORE_NEURALAMPMODELER
