#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>

#include "audio.h"
#include "audio_internal.h"
#include "audio_demuxer_mp3.h"

/*
 * Per-instance state of the MP3 demuxer.
 *
 * The struct tag deliberately avoids a leading underscore: identifiers that
 * start with '_' followed by an uppercase letter are reserved in C.
 */
typedef struct DemuxerMP3Context {
    /* Number of bytes of the frame currently assembled in outputBuffer. */
    uint32_t outputBufferPosition;
    /* Payload bytes still needed to complete the frame held in outputBuffer;
     * 0 when no partial frame is pending. */
    uint32_t missingBytesInLastPacket;
    /* Scratch buffer in which one frame (4-byte header followed by its
     * payload) is assembled before being pushed downstream. */
    AudioBuffer *outputBuffer;
    /* Header bytes collected so far while scanning for a frame sync. */
    uint8_t headerData[4];
    /* Number of header bytes collected so far in headerData. */
    uint8_t headerIndex;
} DemuxerMP3Context;

/*
 * Bitrates in bits per second, indexed by [row][bitrateIndex], where
 * bitrateIndex is the upper nibble of header byte 2. Entries 0 and 15 of
 * every row hold 0; demuxer_mp3_decode_header rejects index 15 outright.
 *
 * Row selection (done in demuxer_mp3_decode_header):
 *   0..2 -> MPEG-1 Layer I, II, III
 *   3    -> MPEG-2 / MPEG-2.5 Layer I
 *   4    -> MPEG-2 / MPEG-2.5 Layers II and III
 */
static uint32_t const bitrateTable[5][16] = {
    /* MPEG-1 */
    { 0,  32000,  64000,  96000, 128000, 160000, 192000, 224000,  /* Layer I   */
         256000, 288000, 320000, 352000, 384000, 416000, 448000, 0 },
    { 0,  32000,  48000,  56000,  64000,  80000,  96000, 112000,  /* Layer II  */
         128000, 160000, 192000, 224000, 256000, 320000, 384000, 0 },
    { 0,  32000,  40000,  48000,  56000,  64000,  80000,  96000,  /* Layer III */
         112000, 128000, 160000, 192000, 224000, 256000, 320000, 0 },

    /* MPEG-2 LSF */
    { 0,  32000,  48000,  56000,  64000,  80000,  96000, 112000,  /* Layer I   */
         128000, 144000, 160000, 176000, 192000, 224000, 256000, 0 },
    { 0,   8000,  16000,  24000,  32000,  40000,  48000,  56000,  /* Layers II & III */
          64000,  80000,  96000, 112000, 128000, 144000, 160000, 0 }
};

/*
 * Sample rates in Hz, indexed by [row][sampleRateIndex], where
 * sampleRateIndex is bits 3..2 of header byte 2 (the value 3 is rejected by
 * demuxer_mp3_decode_header, so only 3 columns are needed).
 * Row 0 is MPEG-1, row 1 is MPEG-2 and row 2 is MPEG-2.5.
 */
static uint32_t const sampleRateTable[3][3] = {
    /* MPEG-1 */
    { 44100, 48000, 32000 },
    /* MPEG-2 */
    { 22050, 24000, 16000 },
    /* MPEG-2.5 */
    { 11025, 12000,  8000 }
};


MP3Header demuxer_mp3_decode_header(uint8_t data[4]) {
    MP3Header header = { 0 };
    
    header.version = (data[1] >> 3) & 0x03;
    header.layer = (data[1] >> 1) & 0x03;
    header.has_crc = (data[1] & 0x01) == 0;
    header.bitrateIndex = (data[2] >> 4) & 0x0f;
    header.sampleRateIndex = (data[2] >> 2) & 0x03;
    header.has_padding = (data[2] >> 1) & 0x01;
    header.is_private = data[2] & 0x01;
    header.channelMode = (data[3] >> 6) & 0x03;
    header.jointStereoModeExtension = (data[3] >> 4) & 0x03;
    header.has_copyright = (data[3] >> 3) & 0x01;
    header.is_original = (data[3] >> 2) & 0x01;
    header.emphasis = data[3] & 0x03;
    
    header.valid = (header.version != MPEGVersionReserved) && (header.layer != MPEGLayerReserved) && (header.bitrateIndex != 0x0f) && (header.sampleRateIndex != 0x03);
    
    if (header.valid) {
        if (header.version == MPEGVersion1) {
            switch (header.layer) {
                case MPEGLayer1:
                    header.bitrate = bitrateTable[0][header.bitrateIndex];
                    break;
                case MPEGLayer2:
                    header.bitrate = bitrateTable[1][header.bitrateIndex];
                    break;
                case MPEGLayer3:
                    header.bitrate = bitrateTable[2][header.bitrateIndex];
                    break;
                case MPEGLayerReserved:
                    break;
            }
        }
        if ((header.version == MPEGVersion2) || (header.version == MPEGVersion2_5)) {
            switch (header.layer) {
                case MPEGLayer1:
                    header.bitrate = bitrateTable[3][header.bitrateIndex];
                    break;
                case MPEGLayer2:
                case MPEGLayer3:
                    header.bitrate = bitrateTable[4][header.bitrateIndex];
                    break;
                case MPEGLayerReserved:
                    break;
            }
        }
        
        switch (header.version) {
            case MPEGVersion1:
                header.sampleRate = sampleRateTable[0][header.sampleRateIndex];
                break;
            case MPEGVersion2:
                header.sampleRate = sampleRateTable[1][header.sampleRateIndex];
                break;
            case MPEGVersion2_5:
                header.sampleRate = sampleRateTable[2][header.sampleRateIndex];
                break;
            case MPEGVersionReserved:
                break;
        }
        
        header.packetLength = 144 * ((float)header.bitrate / (float)header.sampleRate);
        if (header.has_padding) {
            if (header.layer == MPEGLayer1) {
                header.packetLength += 4;
            } else {
                header.packetLength += 1;
            }
        }
        if (header.has_crc) {
            header.packetLength += 2;
        }
        header.packetLength -= 4; // Header
        //fprintf(stderr, "INFO: Packet length: %d\n", header.packetLength);
        if ((header.packetLength > 2016) || (header.packetLength < 96)) {
            header.valid = false;
        }
    } else {
        fprintf(stderr, "Invalid header!\n");
    }
    
    return header;
}

/*
 * Push the frame assembled in the context's output buffer to the next element.
 *
 * For the duration of the call, the output buffer's buf_size is narrowed to the
 * number of bytes actually written. Afterwards it is restored to its previous
 * value, and the write position is rewound to 0 for the next frame.
 *
 * Returns the status reported by the next element's push().
 */
static inline AudioPipelineStatus demuxer_mp3_emit(AudioPipelineElement *self) {
    DemuxerMP3Context *ctx = (DemuxerMP3Context *)self->ctx;
    AudioBuffer *frame = ctx->outputBuffer;
    uint32_t const savedSize = frame->buf_size;

    frame->buf_size = ctx->outputBufferPosition;
    AudioPipelineStatus const status = self->next->push(self->next, frame);
    frame->buf_size = savedSize;

    ctx->outputBufferPosition = 0;
    return status;
}

/*
 * Split an incoming MPEG audio byte stream into frames and push each complete
 * frame (4-byte header followed by its payload) to the next element.
 *
 * How it works:
 *   - The scanner looks for a sync word: 0xFF followed by a byte >= 0xE0.
 *   - It collects the 4 header bytes, possibly across several calls, and
 *     decodes them with demuxer_mp3_decode_header().
 *   - Bytes that cannot start a header are skipped. A warning is printed once
 *     each time sync is lost.
 *   - A frame whose payload runs past the end of `buffer` is kept in the
 *     output buffer and completed at the start of the next call.
 *
 * Returns:
 *   - PipelineBuffering when a frame is still incomplete at the end of
 *     `buffer`,
 *   - PipelineError as soon as the next element reports PipelineError,
 *   - PipelineRunning otherwise.
 */
AudioPipelineStatus demuxer_mp3_push(AudioPipelineElement *self, AudioBuffer *buffer) {
    DemuxerMP3Context *context = (DemuxerMP3Context *)self->ctx;
    AudioBuffer *out = context->outputBuffer;
    uint32_t i = 0;
    bool sync = true;

    /* Complete a frame whose payload was cut off by the previous buffer(s). */
    if (context->missingBytesInLastPacket) {
        uint32_t take = context->missingBytesInLastPacket;
        if (buffer->buf_size < take) {
            take = buffer->buf_size;
        }
        memcpy(out->data + context->outputBufferPosition, buffer->data, take);
        context->outputBufferPosition += take;
        context->missingBytesInLastPacket -= take;
        if (context->missingBytesInLastPacket) {
            return PipelineBuffering;
        }
        if (demuxer_mp3_emit(self) == PipelineError) {
            return PipelineError;
        }
        i = take;
    }

    for (; i < buffer->buf_size; i++) {
        uint8_t const byte = buffer->data[i];

        /* Header byte 0 must be 0xFF; byte 1 must have its top three bits set. */
        if ((context->headerIndex == 0) && (byte != 0xff)) {
            if (sync) {
                fprintf(stderr, "WARN: Sync lost at %u\n", (unsigned)i);
                sync = false;
            }
            continue;
        }
        if ((context->headerIndex == 1) && (byte < 0xe0)) {
            context->headerIndex = 0;
            continue;
        }
        context->headerData[context->headerIndex++] = byte;
        if (context->headerIndex < 4) {
            continue;
        }

        /* All 4 header bytes collected. Decode right away, so that after a
         * rejected header the very next byte is scanned for a new sync. */
        context->headerIndex = 0;
        MP3Header header = demuxer_mp3_decode_header(context->headerData);
        if (!header.valid) {
            continue;
        }
        sync = true;

        uint32_t const payloadLength = header.packetLength;

        /* Between emits buf_size holds the output buffer's capacity
         * (demuxer_mp3_emit restores it); never write past it. */
        uint32_t const capacity = out->buf_size;
        if ((capacity < 4) || (payloadLength > capacity - 4)) {
            fprintf(stderr, "WARN: Skipping %u byte frame, output buffer too small\n", (unsigned)payloadLength + 4u);
            continue;
        }

        uint32_t const payloadStart = i + 1;
        uint32_t const available = buffer->buf_size - payloadStart;

        /* A new frame always starts at the beginning of the output buffer. */
        memcpy(out->data, context->headerData, 4);
        context->outputBufferPosition = 4;

        if (available < payloadLength) {
            memcpy(out->data + 4, buffer->data + payloadStart, available);
            context->outputBufferPosition += available;
            context->missingBytesInLastPacket = payloadLength - available;
            return PipelineBuffering;
        }

        /* The payload fits entirely, including the case where it ends exactly
         * at the end of `buffer`. */
        memcpy(out->data + 4, buffer->data + payloadStart, payloadLength);
        context->outputBufferPosition += payloadLength;
        if (demuxer_mp3_emit(self) == PipelineError) {
            return PipelineError;
        }
        i += payloadLength; /* the loop's i++ then lands on the byte after the payload */
    }

    return PipelineRunning;
}

/*
 * Attach this demuxer downstream of `source`.
 *
 * The only accepted upstream is an element of type AudioElementSource that
 * reports no sample format (sample_rate, channels and bits_per_sample all 0).
 * On success, source->next is pointed at this element and PipelineStopped is
 * returned. Otherwise an error naming the source is printed and PipelineError
 * is returned.
 */
AudioPipelineStatus demuxer_mp3_link(AudioPipelineElement *self, AudioPipelineElement *source) {
    bool const isRawSource = (source->type == AudioElementSource)
        && (source->sample_rate == 0)
        && (source->channels == 0)
        && (source->bits_per_sample == 0);

    if (isRawSource) {
        source->next = self;
        return PipelineStopped;
    }

    fprintf(stderr, "ERROR: MP3 demuxer can only link to a data source, not %s!\n", source->describe(source));
    return PipelineError;
}

/* Human-readable name of this element, e.g. for diagnostics such as link errors. */
char *demuxer_mp3_describe(AudioPipelineElement *self) {
    (void)self; /* the name is the same for every instance */
    return "mp3 demuxer";
}

/*
 * Release the demuxer element, its private context and the frame output
 * buffer held by that context.
 */
void demuxer_mp3_destroy(AudioPipelineElement *self) {
    DemuxerMP3Context *ctx = (DemuxerMP3Context *)self->ctx;

    /* free(NULL) is a no-op, so a missing output buffer needs no guard.
     * NOTE(review): outputBuffer comes from alloc_audio_buffer(); confirm that
     * a plain free() releases everything that allocator obtained (e.g. a
     * separately allocated data array). */
    free(ctx->outputBuffer);
    free(ctx);
    free(self);
}

/*
 * Create an MP3 demuxer pipeline element.
 *
 * The element splits an incoming MPEG audio byte stream into individual
 * frames and pushes each frame to the next element. It owns an output buffer
 * (allocated with size 2048) in which frames are assembled.
 * Start/reset/stop use the shared no-op filter handlers.
 *
 * Returns NULL if any allocation fails; nothing is leaked in that case.
 */
AudioPipelineElement *audio_demuxer_mp3(void) {
    enum { OUTPUT_BUFFER_SIZE = 2048 }; /* must hold 4-byte header + max payload (2016) */

    AudioPipelineElement *self = calloc(1, sizeof *self);
    DemuxerMP3Context *context = calloc(1, sizeof *context);
    if (!self || !context) {
        goto fail;
    }

    context->outputBuffer = alloc_audio_buffer(OUTPUT_BUFFER_SIZE);
    if (!context->outputBuffer) {
        goto fail;
    }

    self->ctx = context;
    self->describe = demuxer_mp3_describe;
    self->start = filter_start_nop;
    self->reset = filter_reset_nop;
    self->stop = filter_stop_nop;
    self->push = demuxer_mp3_push;
    self->link = demuxer_mp3_link;
    self->destroy = demuxer_mp3_destroy;
    self->type = AudioElementDemuxer;

    return self;

fail:
    /* free(NULL) is a no-op, so partially completed allocation is fine. */
    free(context);
    free(self);
    return NULL;
}