#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>

#include "audio.h"
#include "audio_internal.h"
#include "audio_demuxer_adts.h"

/* Private per-element state of the ADTS demuxer (stored in AudioPipelineElement.ctx). */
typedef struct _DemuxerADTSContext {
    /* Number of bytes already written to outputBuffer->data for the packet
     * currently being assembled; reset to 0 after every emit. */
    uint32_t outputBufferPosition;
    /* Payload bytes of the current packet that were not available in the last
     * pushed buffer and are expected at the start of the next one (0 = none). */
    uint32_t missingBytesInLastPacket;
    /* Buffer in which header + payload of one packet are assembled before
     * being pushed to the next pipeline element. */
    AudioBuffer *outputBuffer;
    /* Header bytes collected so far while scanning the input stream. */
    uint8_t headerData[7];
    /* Count of valid bytes in headerData (0..7); 7 means the header is complete. */
    uint8_t headerIndex;
} DemuxerADTSContext;

/* Sampling rate in Hz for each sampling-frequency index of an ADTS header.
 * Indices 13..15 carry no rate and map to 0. */
static uint32_t const sampleRateTable[16] = {
    [0]  = 96000,
    [1]  = 88200,
    [2]  = 64000,
    [3]  = 48000,
    [4]  = 44100,
    [5]  = 32000,
    [6]  = 24000,
    [7]  = 22050,
    [8]  = 16000,
    [9]  = 12000,
    [10] = 11025,
    [11] = 8000,
    [12] = 7350,
    [13] = 0,
    [14] = 0,
    [15] = 0,
};

/**
 * Decode the 7 header bytes of an ADTS frame.
 *
 * The sync word in data[0]/data[1] is not checked here; demuxer_adts_push()
 * matches it before handing the bytes over.
 *
 * @param data  the 7 header bytes
 * @return the decoded header. `valid` is false when the layer bits are not 0
 *         or when the coded frame length is shorter than the 7 header bytes;
 *         "Invalid header!" is then printed to stderr, `sampleRate` stays 0
 *         and `packetLength` holds the raw coded frame length. For a valid
 *         header `packetLength` is the number of bytes that follow the
 *         7 header bytes (coded frame length minus 7).
 */
ADTSHeader demuxer_adts_decode_header(uint8_t data[7]) {
    ADTSHeader header = { 0 };

    header.version = (data[1] >> 3) & 0x01;
    header.has_crc = (data[1] & 0x01) == 0; // a cleared bit signals a CRC
    header.type = ((data[2] >> 6) & 0x03) + 1;
    // The sampling-frequency index is 4 bits wide (it selects one of the
    // 16 sampleRateTable entries); masking with 0x03 would fold e.g.
    // index 4 (44100 Hz) onto index 0 (96000 Hz).
    header.sampleRateIndex = (data[2] >> 2) & 0x0f;
    header.is_private = (data[2] >> 1) & 0x01;
    header.channelMode = ((data[2] & 0x01) << 2) + ((data[3] >> 6) & 0x03);
    header.is_original = (data[3] >> 5) & 0x01;
    header.is_home = (data[3] >> 4) & 0x01;
    header.has_copyright = (data[3] >> 3) & 0x01;
    header.copyright_start = (data[3] >> 2) & 0x01;

    // 13-bit frame length, spread over three bytes; includes the header itself
    uint32_t const frameLength = ((uint32_t)(data[3] & 0x03) << 11)
                               | ((uint32_t)data[4] << 3)
                               | ((uint32_t)(data[5] >> 5) & 0x07);
    header.packetLength = frameLength;

    header.numAACFrames = (data[6] & 0x03) + 1;

    // layer is always set to 0; a frame can never be shorter than its header,
    // and accepting such a length would make the subtraction below wrap around
    bool const layerIsZero = ((data[1] >> 1) & 0x03) == 0;
    header.valid = layerIsZero && frameLength >= 7;

    if (header.valid) {
        header.sampleRate = sampleRateTable[header.sampleRateIndex];
        header.packetLength = frameLength - 7; // Header
    } else {
        fprintf(stderr, "Invalid header!\n");
    }

    return header;
}

/*
 * Hand the packet assembled in the context's output buffer to the next
 * pipeline element.
 *
 * The buffer's buf_size is temporarily set to the number of assembled bytes
 * for the duration of the downstream push and restored afterwards; the write
 * position is then reset to 0 so the next packet starts at the beginning.
 *
 * Returns whatever the next element's push() returned.
 */
static inline AudioPipelineStatus demuxer_adts_emit(AudioPipelineElement *self) {
    DemuxerADTSContext *ctx = (DemuxerADTSContext *)self->ctx;
    AudioBuffer *packet = ctx->outputBuffer;
    AudioPipelineElement *sink = self->next;

    uint32_t const saved_size = packet->buf_size;

    packet->buf_size = ctx->outputBufferPosition;
    AudioPipelineStatus const status = sink->push(sink, packet);
    packet->buf_size = saved_size;

    ctx->outputBufferPosition = 0;
    return status;
}

AudioPipelineStatus demuxer_adts_push(AudioPipelineElement *self, AudioBuffer *buffer) {
    DemuxerADTSContext *context = (DemuxerADTSContext *)self->ctx;
    uint32_t start = 0;
    bool sync = true;
    
    if (context->missingBytesInLastPacket) {
        if (buffer->buf_size < context->missingBytesInLastPacket) {
            uint32_t remaining_bytes = (context->missingBytesInLastPacket - buffer->buf_size);
            //fprintf(stderr, "INFO: Last buffer was too short, copying %d bytes, %d remaining \n", buffer->buf_size, remaining_bytes);
            memcpy(context->outputBuffer->data + context->outputBufferPosition, buffer->data, buffer->buf_size);
            context->outputBufferPosition += buffer->buf_size;
            context->missingBytesInLastPacket = remaining_bytes;
            return PipelineBuffering;
        }
        //fprintf(stderr, "INFO: Last buffer was too short, copying %d bytes\n", context->missingBytesInLastPacket);
        memcpy(context->outputBuffer->data + context->outputBufferPosition, buffer->data, context->missingBytesInLastPacket);
        context->outputBufferPosition += context->missingBytesInLastPacket;
        demuxer_adts_emit(self);
        start = context->missingBytesInLastPacket;
        context->missingBytesInLastPacket = 0;
    }
    
    for (uint32_t i = start; i < buffer->buf_size; i++) {
        switch (context->headerIndex) {
            case 0:
                if (buffer->data[i] != 0xff) {
                    if (sync) {
                        fprintf(stderr, "WARN: Sync lost at %d\n", i);
                        sync = false;
                    }
                    continue;
                }
                context->headerData[context->headerIndex++] = buffer->data[i];
                continue;
            case 1:
                if (buffer->data[i] < 0xf0) {
                    context->headerIndex = 0;
                    continue;
                }
                context->headerData[context->headerIndex++] = buffer->data[i];
                continue;
            case 2:
            case 3:
            case 4:
            case 5:
            case 6:
                context->headerData[context->headerIndex++] = buffer->data[i];
                continue;
            default:
                break;
        }
        context->headerIndex = 0;
        
        // sync marker found, try to decode the header
        ADTSHeader header = demuxer_adts_decode_header(context->headerData);
        if (!header.valid) {
            continue;
        }
        sync = true;
        if (i + header.packetLength < buffer->buf_size) {
            // FIXME: realloc output Buffer size if too small
            //fprintf(stderr, "INFO: Found frame sync at %d, copying %d bytes\n", i, header.packetLength);
            memcpy(context->outputBuffer->data + context->outputBufferPosition, context->headerData, 7);
            context->outputBufferPosition += 7;
            memcpy(context->outputBuffer->data + context->outputBufferPosition, buffer->data + i, header.packetLength);
            i += header.packetLength - 1;
            context->outputBufferPosition += header.packetLength;
            demuxer_adts_emit(self);
        } else {
            uint32_t remaining_bytes = (buffer->buf_size - i);
            //fprintf(stderr, "INFO: Found frame sync at %d, buffer too short copying %d bytes\n", i, remaining_bytes);
            memcpy(context->outputBuffer->data + context->outputBufferPosition, context->headerData, 7);
            context->outputBufferPosition += 7;
            memcpy(context->outputBuffer->data + context->outputBufferPosition, buffer->data + i, remaining_bytes);
            context->outputBufferPosition += remaining_bytes;
            context->missingBytesInLastPacket = header.packetLength - remaining_bytes;
            return PipelineBuffering;
        }
    }

    return PipelineRunning;
}

/**
 * Attach the demuxer behind `source`.
 *
 * Linking is only accepted when `source` is of type AudioElementSource and
 * reports no sample rate, channel count or bit depth. On success the source's
 * `next` pointer is set to this element.
 *
 * @return PipelineStopped on success, PipelineError (after logging to stderr)
 *         when the source is not acceptable.
 */
AudioPipelineStatus demuxer_adts_link(AudioPipelineElement *self, AudioPipelineElement *source) {
    bool const is_plain_data_source =
        source->sample_rate == 0 &&
        source->channels == 0 &&
        source->bits_per_sample == 0 &&
        source->type == AudioElementSource;

    if (is_plain_data_source) {
        source->next = self;
        return PipelineStopped;
    }

    fprintf(stderr, "ERROR: ADTS demuxer can only link to a data source, not %s!\n", source->describe(source));
    return PipelineError;
}

/* Human-readable name of this element; the argument is not used and the
 * returned string is a literal that must not be modified or freed. */
char *demuxer_adts_describe(AudioPipelineElement *self) {
    (void)self;
    return "adts demuxer";
}

/**
 * Release the demuxer element together with its private context and output
 * buffer. Passing NULL, or an element whose context was never set, is safe.
 */
void demuxer_adts_destroy(AudioPipelineElement *self) {
    if (self == NULL) {
        return;
    }

    DemuxerADTSContext *context = (DemuxerADTSContext *)self->ctx;
    if (context != NULL) {
        // NOTE(review): outputBuffer comes from alloc_audio_buffer(); if that
        // allocates the data area separately, a plain free() leaks it --
        // confirm and use the matching release function if one exists.
        free(context->outputBuffer); // free(NULL) is a no-op, no guard needed
        free(context);
    }
    free(self);
}

AudioPipelineElement *audio_demuxer_adts(void) {
    AudioPipelineElement *self = calloc(1, sizeof(AudioPipelineElement));
    DemuxerADTSContext *context = calloc(1, sizeof(DemuxerADTSContext));
    
    context->outputBuffer = alloc_audio_buffer(
        4 /* max frames per adts packet */
        * 768 /* max aac packet size per channel */
        * 2 /* channels */
    );
    self->ctx = context;
    self->describe = demuxer_adts_describe;
    self->start = filter_start_nop;
    self->reset = filter_reset_nop;
    self->stop = filter_stop_nop;
    self->push = demuxer_adts_push;
    self->link = demuxer_adts_link;
    self->destroy = demuxer_adts_destroy;
    self->type = AudioElementDemuxer;
    
    return self;
}