#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>

#include "audio.h"
#include "audio_internal.h"
#include "audio_demuxer_mp3.h"
#include "deps/mp3/stream.h"
#include "deps/mp3/frame.h"
#include "deps/mp3/synth.h"

/* Short aliases for the libmad state objects used by this decoder. */
typedef struct mad_stream MADStream;
typedef struct mad_frame MADFrame;
typedef struct mad_synth MADSynth;

/*
 * Per-instance state of the MP3 decoder element, stored in self->ctx.
 *
 * The three libmad objects are heap-allocated by audio_decoder_mp3() and
 * released by decoder_mp3_destroy(), so this context owns them.
 * outputBuffer is reused across pushes and only replaced when a decoded
 * frame needs more space than it currently holds.
 *
 * The struct tag deliberately has no leading underscore: identifiers that
 * start with '_' followed by an uppercase letter are reserved (C11 7.1.3).
 */
typedef struct DecoderMP3Context {
    MADStream *stream;         /* bitstream state; re-pointed at each input buffer */
    MADFrame *frame;           /* most recently decoded frame */
    MADSynth *synth;           /* synthesis state; pcm holds the last frame's samples */
    AudioBuffer *outputBuffer; /* 16-bit PCM passed downstream; NULL until first use */
} DecoderMP3Context;

/*
 * Convert one libmad fixed-point sample to a signed 16-bit PCM value.
 *
 * Steps:
 *   1. Round: add half of the smallest 16-bit step.
 *   2. Clip to [-MAD_F_ONE, MAD_F_ONE).
 *   3. Quantize: keep the top 16 significant bits.
 *
 * The work is done in long long. Adding the rounding term directly to a
 * mad_fixed_t close to its maximum would overflow a signed integer, which is
 * undefined behavior. For every input that did not overflow, the result is
 * the same as the classic minimad-style scaler.
 */
static inline int16_t scale(mad_fixed_t sample) {
    long long value = (long long)sample + (1LL << (MAD_F_FRACBITS - 16));

    if (value >= MAD_F_ONE) {
        value = MAD_F_ONE - 1;
    } else if (value < -MAD_F_ONE) {
        value = -MAD_F_ONE;
    }

    /* After clipping, this shift lands exactly in [-32768, 32767]. */
    return (int16_t)(value >> (MAD_F_FRACBITS + 1 - 16));
}

AudioPipelineStatus decoder_mp3_push(AudioPipelineElement *self, AudioBuffer *buffer) {
    DecoderMP3Context *context = (DecoderMP3Context *)self->ctx;
    int result = 0;
    
    mad_stream_buffer(context->stream, buffer->data, buffer->buf_size);

    result = mad_frame_decode(context->frame, context->stream);
    if (result == -1) {
        fprintf(stderr, "ERROR: frame decoding error 0x%04x (%s)\n",
                context->stream->error, mad_stream_errorstr(context->stream));
        if (!MAD_RECOVERABLE(context->stream->error)) {
            mad_frame_mute(context->frame);
            return PipelineError;
        }
        return PipelineRunning;
    }
    mad_synth_frame(context->synth, context->frame);

    bool relink_needed = false;
    if ((self->channels == 0) || (self->channels != context->synth->pcm.channels)) {
        self->channels = context->synth->pcm.channels;
        relink_needed = true;
    }
    if ((self->sample_rate == 0) || (self->sample_rate != context->synth->pcm.samplerate)) {
        self->sample_rate = context->synth->pcm.samplerate;
        relink_needed = true;
    }
    
    if (relink_needed) {
        AudioPipelineStatus result = self->next->link(self->next, self);
        if (result != PipelineStopped) {
            fprintf(stderr, "ERROR: Re-link failed\n");
            return PipelineError;
        }
    }
    
    uint16_t numSamples = context->synth->pcm.length;
    if ((context->outputBuffer != NULL) && (context->outputBuffer->buf_size < numSamples * self->channels * 2)) {
        // realloc buffer
        free_audio_buffer(context->outputBuffer);
        context->outputBuffer = NULL;
    }
    if (context->outputBuffer == NULL) {
        // allocate output buffer
        context->outputBuffer = alloc_audio_buffer(numSamples * self->channels * 2);
        if (context->outputBuffer == NULL) {
            fprintf(stderr, "ERROR: Cannot allocate output buffer of size %d!\n", numSamples * self->channels * 2);
            return PipelineError;
        }
    }
    
    // Scale samples to 16 bit
    mad_fixed_t const *left = context->synth->pcm.samples[0];
    mad_fixed_t const *right = context->synth->pcm.samples[1];
    uint32_t bufferPos = 0;
    while (numSamples--) {
        signed int sample;
        
        /* output sample(s) in 16-bit signed little-endian PCM */
        
        sample = scale(*left++);
        context->outputBuffer->data[bufferPos++] = (sample >> 0) & 0xff;
        context->outputBuffer->data[bufferPos++] = (sample >> 8) & 0xff;
        
        if (context->synth->pcm.channels == 2) {
            sample = scale(*right++);
            context->outputBuffer->data[bufferPos++] = (sample >> 0) & 0xff;
            context->outputBuffer->data[bufferPos++] = (sample >> 8) & 0xff;
        }
    }
    
    // run next element of the pipeline
    uint32_t buf_sz = context->outputBuffer->buf_size;
    context->outputBuffer->buf_size = bufferPos;
    AudioPipelineStatus ret = self->next->push(self->next, context->outputBuffer);
    if ((ret != PipelineRunning) && (ret != PipelineBuffering)) {
        return ret;
    }
    context->outputBuffer->buf_size = buf_sz;

    return PipelineRunning;
}

/*
 * Attach this decoder downstream of `source`.
 *
 * Accepted upstreams are elements of type AudioElementSource or
 * AudioElementDemuxer that advertise no PCM format (sample_rate, channels and
 * bits_per_sample all zero).
 *
 * On success, source->next is set to this decoder and PipelineStopped is
 * returned. Otherwise an error naming the source is printed and PipelineError
 * is returned.
 */
AudioPipelineStatus decoder_mp3_link(AudioPipelineElement *self, AudioPipelineElement *source) {
    bool advertises_format = (source->sample_rate != 0)
                          || (source->channels != 0)
                          || (source->bits_per_sample != 0);
    bool is_data_element = (source->type == AudioElementSource)
                        || (source->type == AudioElementDemuxer);

    if (is_data_element && !advertises_format) {
        source->next = self;
        return PipelineStopped;
    }

    fprintf(stderr, "ERROR: MP3 decoder can only link to a data source, not %s!\n", source->describe(source));
    return PipelineError;
}

/* Human-readable name of this pipeline element; the same for every instance. */
char *decoder_mp3_describe(AudioPipelineElement *self) {
    (void)self; /* the description does not depend on instance state */
    char *description = "libmad MP3 decoder";
    return description;
}

/*
 * Release the decoder element and everything its context owns:
 *   - the reusable output buffer, if one was allocated;
 *   - the libmad synth, frame and stream objects (each is finished, then freed);
 *   - the context itself;
 *   - the element struct.
 */
void decoder_mp3_destroy(AudioPipelineElement *self) {
    DecoderMP3Context *ctx = (DecoderMP3Context *)self->ctx;

    if (ctx->outputBuffer != NULL) {
        free_audio_buffer(ctx->outputBuffer);
    }

    mad_synth_finish(ctx->synth);
    free(ctx->synth);

    mad_frame_finish(ctx->frame);
    free(ctx->frame);

    mad_stream_finish(ctx->stream);
    free(ctx->stream);

    free(ctx);
    free(self);
}

/*
 * Create a libmad-backed MP3 decoder pipeline element.
 *
 * The element advertises 16-bit output. channels and sample_rate start at
 * zero (the element is calloc'ed) and are filled in when the first frame is
 * decoded.
 *
 * Returns NULL if any allocation fails. Nothing is leaked in that case.
 * A successfully created element is released through its destroy callback.
 */
AudioPipelineElement *audio_decoder_mp3(void) {
    AudioPipelineElement *self = calloc(1, sizeof *self);
    DecoderMP3Context *context = calloc(1, sizeof *context);
    if ((self == NULL) || (context == NULL)) {
        goto fail;
    }

    context->stream = malloc(sizeof *context->stream);
    context->frame = malloc(sizeof *context->frame);
    context->synth = malloc(sizeof *context->synth);
    if ((context->stream == NULL) || (context->frame == NULL) || (context->synth == NULL)) {
        goto fail;
    }

    mad_stream_init(context->stream);
    mad_frame_init(context->frame);
    mad_synth_init(context->synth);

    mad_stream_options(context->stream, 0);

    self->bits_per_sample = 16;

    self->ctx = context;
    self->describe = decoder_mp3_describe;
    self->start = filter_start_nop;
    self->reset = filter_reset_nop;
    self->stop = filter_stop_nop;
    self->push = decoder_mp3_push;
    self->link = decoder_mp3_link;
    self->destroy = decoder_mp3_destroy;
    self->type = AudioElementDecoder;

    return self;

fail:
    /* No libmad *_init() ran on this path, so only raw memory needs freeing.
       The context was calloc'ed, so unallocated members are NULL and free(NULL)
       is a no-op. */
    if (context != NULL) {
        free(context->synth);
        free(context->frame);
        free(context->stream);
    }
    free(context);
    free(self);
    return NULL;
}