/* Simple multi-threaded FLAC file encoder using libFLAC.  Written by
   Frederick Akalin by adapting example_c_encode_file written by Josh
   Coalson. */

#include <aio.h>
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>

#include "FLAC/metadata.h"
#include "FLAC/stream_encoder.h"

/* Size in bytes of the canonical WAVE header that main() validates;
   the PCM data handed to the encoder threads starts at this offset
   in the input file. */
#define WAV_HEADER_SIZE 44
/* We want a constant block size for ease of comparison.  main() also
   cuts the input into shards that are whole multiples of this many
   samples (except for the last shard, which takes the remainder). */
#define BLOCKSIZE 4096
/* We limit the number of threads to avoid having to manage our aio
   calls: main() issues exactly one aio_write() per thread and keeps
   them all outstanding at once.
   NOTE(review): AIO_LISTIO_MAX may not be defined on every platform
   -- confirm for the targets this is built on. */
#define MAX_THREADS AIO_LISTIO_MAX
/* Used for the conversion to PCM in encode_shard(): the number of
   samples (each covering all channels) converted and fed to the
   encoder per call. */
#define INPUT_BUFSIZE 1024
/* Used in figuring out when to display progress messages: a thread
   prints one each time it has written more than
   (input file size / PROGRESS_INTERVAL) bytes since the last one. */
#define PROGRESS_INTERVAL 100

/* Variables describing the input file. */
/* Filled in once by main() and shared read-only (via pointer) by all
   shards. */
struct input_info {
  /* Size of the whole input file in bytes, header included (taken
     from fstat()). */
  off_t byte_size;
  /* Number of channels; main() only accepts stereo and sets this
     to 2. */
  unsigned channels;
  /* Bits per sample; main() only accepts 16-bit input and sets this
     to 16. */
  unsigned bps;
  /* Bytes occupied by one sample across all channels:
     channels * (bps / 8). */
  unsigned sample_byte_size;
  /* Sample rate, read little-endian from bytes 24..27 of the WAVE
     header. */
  unsigned sample_rate;
  /* Number of samples (each covering all channels) in the file,
     derived from the data size stored at bytes 40..43 of the
     header. */
  unsigned total_samples;
};

/* Per-thread state: describes the slice of the input one encoder
   thread works on and the slice of the output buffer it may fill.
   Also passed as client_data to the libFLAC callbacks. */
struct shard_info {
  /* Input variables. */

  /* Index of the shard/thread; shard 0 is the one that emits the
     stream metadata. */
  int thread_number;
  /* Frame number the encoder for this shard is told to start at
     (shard index times the number of blocks per shard). */
  FLAC__uint64 current_frame_number;
  /* Description of the input file; shared with the other shards. */
  struct input_info *input_info;
  /* First byte of this shard's PCM data inside the mapped input. */
  FLAC__byte *bufin;
  /* Number of samples (each covering all channels) to encode. */
  unsigned num_samples;
  /* Start of this shard's region of the output buffer. */
  FLAC__byte *bufout;
  /* One past the last byte of that region; main() sizes the region
     to the shard's input size in bytes. */
  FLAC__byte *bufout_end;

  /* Modified variables. */

  /* Bytes written since the last progress message. */
  unsigned byte_counter;
  /* Current write position; moved by the write and seek callbacks. */
  FLAC__byte *bufout_cur;
  /* Furthest position bufout_cur has reached.  bufout_cur_max -
     bufout is the number of output bytes the shard produced. */
  FLAC__byte *bufout_cur_max;
};

/* Given info about a shard, does all the encoding work for it:
   creates and configures an encoder, feeds it the shard's samples and
   tears it down again.  Output is collected in the shard's output
   buffer through the stream callbacks.  Returns true on success and
   false on any failure. */
FLAC__bool encode_shard(struct shard_info *shard_info);

/* pthread wrapper for encode_shard().  arg is the struct shard_info
   to encode.  The FLAC__bool result travels back through the void *
   return value (NULL on failure, non-NULL on success) and is picked
   up by main() via pthread_join(). */
void *encode_shard_thread(void *arg) {
  /* Go through intptr_t so that the integer-to-pointer conversion is
     done from a pointer-sized integer instead of a plain int. */
  return (void *)(intptr_t)encode_shard((struct shard_info *)arg);
}

int main(int argc, char *argv[]) {
  int status = EXIT_SUCCESS;
  int fdin;
  int fdout;
  int num_threads;
  struct input_info input_info;
  FLAC__byte *bufin;
  FLAC__byte *bufout;
  struct shard_info shard_infos[MAX_THREADS];
  pthread_t threads[MAX_THREADS];
  struct aiocb aiocbs[MAX_THREADS];

  /* Read in arguments. */

  if(argc != 3 && argc != 4) {
    printf("usage: %s infile.wav outfile.flac [thread count]\n", argv[0]);
    status = EXIT_FAILURE;
    goto cleanup_nothing;
  }

  fdin = open(argv[1], O_RDONLY);
  if (fdin < 0) {
    perror(argv[1]);
    status = EXIT_FAILURE;
    goto cleanup_nothing;
  }

  fdout = open(argv[2], O_WRONLY | O_CREAT | O_TRUNC);
  if (fdout < 0) {
    perror(argv[2]);
    status = EXIT_FAILURE;
    goto cleanup_fdin;
  }

  num_threads = (argc < 4) ? 1 : (int)strtol(argv[3], NULL, 10);
  if (num_threads < 1) num_threads = 1;
  else if (num_threads > MAX_THREADS) num_threads = MAX_THREADS;

  printf("Using %d encoding threads\n", num_threads);

  /* mmap input file in. */

  {
    struct stat buf;
    if (fstat(fdin, &buf) < 0) {
      perror("fstat");
      status = EXIT_FAILURE;
      goto cleanup_fdout;
    }
    input_info.byte_size = buf.st_size;
  }

  bufin = mmap(NULL, input_info.byte_size, PROT_READ, MAP_SHARED, fdin, 0);
  if (bufin == (FLAC__byte *)-1) {
    perror("mmap");
    status = EXIT_FAILURE;
    goto cleanup_fdout;
  }

  /* Set correct permissions on output file for convenience. */

  if (fchmod(fdout, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH) != 0) {
    perror("fchmod");
    status = EXIT_FAILURE;
    goto cleanup_fdout;
  }

  /* Read wav header and validate it. */

  if((input_info.byte_size < WAV_HEADER_SIZE) ||
     memcmp(bufin, "RIFF", 4) ||
     memcmp(bufin+8, "WAVEfmt \020\000\000\000\001\000\002\000", 16) ||
     memcmp(bufin+32, "\004\000\020\000data", 8)) {
    fprintf(stderr,
	    "ERROR: invalid/unsupported WAVE file; "
	    "only 16bps stereo WAVE in canonical form allowed\n");
    status = EXIT_FAILURE;
    goto cleanup_bufin;
  }
  input_info.channels = 2;
  input_info.bps = 16;
  input_info.sample_byte_size = input_info.channels * (input_info.bps / 8);
  input_info.sample_rate =
    ((((((unsigned)bufin[27] << 8) | bufin[26]) << 8) | bufin[25]) << 8) |
    bufin[24];
  input_info.total_samples =
    (((((((unsigned)bufin[43] << 8) | bufin[42]) << 8) | bufin[41]) << 8) |
     bufin[40]) / 4;

  /* We assume that the output data will be no larger than the input
     data. */

  bufout = malloc(input_info.byte_size - WAV_HEADER_SIZE);
  if (bufout == NULL) {
    perror("malloc");
    status = EXIT_FAILURE;
    goto cleanup_bufin;
  }

  /* Set up thread variables. */

  {
    /* These numbers are possible underestimates for the last thread. */
    unsigned blocks_per_thread =
      input_info.total_samples / num_threads / BLOCKSIZE;
    unsigned samples_per_thread = blocks_per_thread * BLOCKSIZE;
    unsigned bytes_per_thread =
      samples_per_thread * input_info.sample_byte_size;
    int i;
    for (i = 0; i < num_threads; ++i) {
      /* Messy logic to deal with the fact that the last thread won't
        have the same number of samples. */
      unsigned num_samples =
	((i < (num_threads - 1)) ?
	 samples_per_thread :
	 (input_info.total_samples -
	  (num_threads - 1) * samples_per_thread));

      shard_infos[i].thread_number = i;
      shard_infos[i].current_frame_number = i * blocks_per_thread;
      shard_infos[i].input_info = &input_info;
      shard_infos[i].bufin = bufin + WAV_HEADER_SIZE + i * bytes_per_thread;
      shard_infos[i].num_samples = num_samples;
      shard_infos[i].bufout = bufout + i * bytes_per_thread;
      shard_infos[i].bufout_end =
	shard_infos[i].bufout + num_samples * input_info.sample_byte_size;

      shard_infos[i].byte_counter = 0;
      shard_infos[i].bufout_cur = shard_infos[i].bufout;
      shard_infos[i].bufout_cur_max = shard_infos[i].bufout;
    }
  }

  /* Spawn the threads. */

  {
    int i;
    for (i = 0; i < num_threads; ++i) {
      if (pthread_create(&threads[i],
			 NULL,
			 &encode_shard_thread,
			 &shard_infos[i]) != 0) {
	perror("pthread_create");
	status = EXIT_FAILURE;
	goto cleanup_bufout;
      }
    }
  }

  /* Wait for each thread in sequence and queue up output writes. */

  {
    unsigned long byte_offset = 0;
    int i;
    /* We need to zero out any aiocb struct that we use before we fill
       in any members. */
    memset(aiocbs, 0, num_threads * sizeof(*aiocbs));
    for (i = 0; i < num_threads; ++i) {
      unsigned long bytes_written;
      void *ret;
      if (pthread_join(threads[i], &ret) != 0) {
	fprintf(stderr, "could not join thread %d\n", i);
	status = EXIT_FAILURE;
	goto cleanup_bufout;
      }
      if (!(FLAC__bool)ret) {
	fprintf(stderr, "thread %d did not succeed\n", i);
	status = EXIT_FAILURE;
	goto cleanup_bufout;
      }
      bytes_written = shard_infos[i].bufout_cur_max - shard_infos[i].bufout;
      printf("queueing up %lu bytes from thread %d\n",
	     bytes_written, i);
      aiocbs[i].aio_buf = shard_infos[i].bufout;
      aiocbs[i].aio_nbytes = bytes_written;
      aiocbs[i].aio_offset = byte_offset;
      aiocbs[i].aio_fildes = fdout;
      if (aio_write(&aiocbs[i]) != 0) {
	perror("aio_write");
	status = EXIT_FAILURE;
	goto cleanup_bufout;
      }
      byte_offset += bytes_written;
    }
  }

  /* Wait for all output writes to finish. */

  {
    struct timeval t_start;
    int i;
    if (gettimeofday(&t_start, NULL)) {
      perror("gettimeofday");
      status = EXIT_FAILURE;
      goto cleanup_bufout;
    }
    printf("writing to output file...\n");
    for (i = 0; i < num_threads; ++i) {
      const struct aiocb *aiocbp = &aiocbs[i];
      if (aio_suspend(&aiocbp, 1, NULL) < 0) {
	perror("aio_suspend");
	status = EXIT_FAILURE;
	goto cleanup_bufout;
      }
      if (aio_return(&aiocbs[i]) < 0) {
	perror("aio_return");
	status = EXIT_FAILURE;
	goto cleanup_bufout;
      }
    }

    /* Make sure we fsync so we get an accurate I/O time count,
     although we're assuming that all the encoding threads end at
     about the same time. */
    fsync(fdout);
    {
      struct timeval t_end;
      double dt;
      if (gettimeofday(&t_end, NULL)) {
	perror("gettimeofday");
	status = EXIT_FAILURE;
	goto cleanup_bufout;
      }
      dt =
        ((double)t_end.tv_sec + (double)t_end.tv_usec / 1000000) -
	((double)t_start.tv_sec + (double)t_start.tv_usec / 1000000);
      printf("writing took %f s\n", dt);
    }
  }

 cleanup_bufout:
  free(bufout);
 cleanup_bufin:
  munmap(bufin, input_info.byte_size);
 cleanup_fdout:
  close(fdout);
 cleanup_fdin:
  close(fdin);
 cleanup_nothing:
  return status;
}

/* Callbacks used by encode_shard(). */

/* libFLAC write callback: appends `bytes` bytes from `buffer` at the
   shard's current output position.

   client_data is the struct shard_info of the calling thread.
   Returns FLAC__STREAM_ENCODER_WRITE_STATUS_FATAL_ERROR if the data
   does not fit into the shard's output region, and
   FLAC__STREAM_ENCODER_WRITE_STATUS_OK otherwise (including when the
   write is deliberately dropped, see below). */
FLAC__StreamEncoderWriteStatus
write_callback(const FLAC__StreamEncoder *encoder,
	       const FLAC__byte buffer[],
	       size_t bytes,
	       unsigned samples,
	       unsigned current_frame,
	       void *client_data) {
  struct shard_info *shard_info = (struct shard_info *)client_data;

  /* Metadata processing should be skipped for all but the first
     thread, but check anyway. */
  if ((shard_info->thread_number != 0) && (current_frame == 0)) {
    return FLAC__STREAM_ENCODER_WRITE_STATUS_OK;
  }

  /* Compare against the space that is left rather than forming
     bufout_cur + bytes, which could point past the buffer. */
  if (bytes > (size_t)(shard_info->bufout_end - shard_info->bufout_cur)) {
    fprintf(stderr, "output buffer not big enough\n");
    return FLAC__STREAM_ENCODER_WRITE_STATUS_FATAL_ERROR;
  }
  memcpy(shard_info->bufout_cur, buffer, bytes);
  shard_info->bufout_cur += bytes;
  shard_info->byte_counter += bytes;
  /* Since bufout_cur may go backwards, we can't use it to figure out
     how much we've actually written.  Update the high-water mark
     before reporting progress so the message includes this write. */
  if (shard_info->bufout_cur > shard_info->bufout_cur_max) {
    shard_info->bufout_cur_max = shard_info->bufout_cur;
  }
  /* Hackish progress output. */
  if (shard_info->byte_counter >
      (shard_info->input_info->byte_size / PROGRESS_INTERVAL)) {
    /* The pointer difference is a ptrdiff_t; convert it explicitly so
       that it matches the format specifier. */
    printf("encode thread %d written %lu bytes\n",
	   shard_info->thread_number,
	   (unsigned long)(shard_info->bufout_cur_max - shard_info->bufout));
    shard_info->byte_counter -=
      shard_info->input_info->byte_size / PROGRESS_INTERVAL;
  }
  return FLAC__STREAM_ENCODER_WRITE_STATUS_OK;
}

void metadata_callback(const FLAC__StreamEncoder *encoder,
		       const FLAC__StreamMetadata *metadata,
		       void *client_data) {
  struct shard_info *shard_info = (struct shard_info *)client_data;

  /* TODO: figure out a cleaner way of intercepting the metadata
     instead of exposing this private function. */
  void update_metadata_(const FLAC__StreamEncoder *encoder);

  /* Insert the real total_samples into the metadata and zero out the
   min/max framesize for now. */

  if (shard_info->thread_number == 0) {
    FLAC__StreamMetadata *mutable_metadata =
      (FLAC__StreamMetadata *)metadata;
    unsigned old_total_samples = metadata->data.stream_info.total_samples;
    unsigned old_min_framesize = metadata->data.stream_info.min_framesize;
    unsigned old_max_framesize = metadata->data.stream_info.max_framesize;

    mutable_metadata->data.stream_info.min_framesize = 0;
    mutable_metadata->data.stream_info.max_framesize = 0;
    mutable_metadata->data.stream_info.total_samples =
      shard_info->input_info->total_samples;

    update_metadata_(encoder);

    mutable_metadata->data.stream_info.total_samples = old_total_samples;
    mutable_metadata->data.stream_info.min_framesize = old_min_framesize;
    mutable_metadata->data.stream_info.max_framesize = old_max_framesize;
  }
}

/* libFLAC seek callback: moves the shard's output position to
   absolute_byte_offset, measured from the start of the shard's
   output region.

   client_data is the struct shard_info of the calling thread.
   Returns FLAC__STREAM_ENCODER_SEEK_STATUS_ERROR if the offset lies
   beyond the end of the region, FLAC__STREAM_ENCODER_SEEK_STATUS_OK
   otherwise. */
FLAC__StreamEncoderSeekStatus
seek_callback(const FLAC__StreamEncoder *encoder,
	      FLAC__uint64 absolute_byte_offset,
	      void *client_data) {
  struct shard_info *shard_info = (struct shard_info *)client_data;
  /* The offset is absolute, so it has to be checked against the size
     of the whole region (measured from bufout), not against the
     space left after the current position. */
  if (absolute_byte_offset >
      (FLAC__uint64)(shard_info->bufout_end - shard_info->bufout)) {
    fprintf(stderr, "seeking too far\n");
    return FLAC__STREAM_ENCODER_SEEK_STATUS_ERROR;
  }
  shard_info->bufout_cur = shard_info->bufout + absolute_byte_offset;
  return FLAC__STREAM_ENCODER_SEEK_STATUS_OK;
}

/* libFLAC tell callback: stores the shard's current output position,
   as a byte offset from the start of its output region, in
   *absolute_byte_offset.  client_data is the struct shard_info of the
   calling thread.  Always reports success. */
FLAC__StreamEncoderTellStatus
tell_callback(const FLAC__StreamEncoder *encoder,
	      FLAC__uint64 *absolute_byte_offset,
	      void *client_data) {
  const struct shard_info *shard = client_data;
  const FLAC__byte *region_start = shard->bufout;
  const FLAC__byte *position = shard->bufout_cur;

  *absolute_byte_offset = position - region_start;
  return FLAC__STREAM_ENCODER_TELL_STATUS_OK;
}

/* Encodes one shard: creates and configures an encoder from
   shard_info->input_info, feeds it shard_info->num_samples samples
   starting at shard_info->bufin, and finishes the stream.  The
   encoded bytes are collected in the shard's output region by the
   stream callbacks above.  The first shard (thread_number == 0) also
   attaches the metadata blocks.

   Returns true if every step succeeded and false otherwise.  The
   encoder and any metadata objects are released in both cases. */
FLAC__bool encode_shard(struct shard_info *shard_info) {
  struct input_info *input_info = shard_info->input_info;
  FLAC__bool ok = true;
  FLAC__StreamEncoder *encoder;
  /* NULL means "not allocated", so cleanup is correct on every path. */
  FLAC__StreamMetadata *metadata[2] = { NULL, NULL };

  printf("starting encode thread %d\n", shard_info->thread_number);

  encoder = FLAC__stream_encoder_new();
  if (encoder == NULL) {
    fprintf(stderr, "ERROR: allocating encoder\n");
    return false;
  }

  /* Set options from input_info. */

  ok &= FLAC__stream_encoder_set_blocksize(encoder, BLOCKSIZE);
  ok &= FLAC__stream_encoder_set_compression_level(encoder, 5);
  ok &= FLAC__stream_encoder_set_channels(encoder, input_info->channels);
  ok &= FLAC__stream_encoder_set_bits_per_sample(encoder, input_info->bps);
  ok &=
    FLAC__stream_encoder_set_sample_rate(encoder, input_info->sample_rate);
  ok &=
    FLAC__stream_encoder_set_total_samples_estimate(encoder,
						    input_info->
						    total_samples);

  /* Turn off md5 and verification for now. */

  ok &= FLAC__stream_encoder_set_do_md5(encoder, false);
  ok &= FLAC__stream_encoder_set_verify(encoder, false);

  /* The first thread is responsible for metadata. */
  if (ok && (shard_info->thread_number == 0)) {
    /* Copied from example encoder for ease of comparison. */
    FLAC__StreamMetadata_VorbisComment_Entry entry;
    if(
       (metadata[0] = FLAC__metadata_object_new(FLAC__METADATA_TYPE_VORBIS_COMMENT)) == NULL ||
       (metadata[1] = FLAC__metadata_object_new(FLAC__METADATA_TYPE_PADDING)) == NULL ||
       /* there are many tag (vorbiscomment) functions but these are convenient for this particular use: */
       !FLAC__metadata_object_vorbiscomment_entry_from_name_value_pair(&entry, "ARTIST", "Some Artist") ||
       !FLAC__metadata_object_vorbiscomment_append_comment(metadata[0], entry, /*copy=*/false) || /* copy=false: let metadata object take control of entry's allocated string */
       !FLAC__metadata_object_vorbiscomment_entry_from_name_value_pair(&entry, "YEAR", "1984") ||
       !FLAC__metadata_object_vorbiscomment_append_comment(metadata[0], entry, /*copy=*/false)
       ) {
      fprintf(stderr, "ERROR: out of memory or tag error\n");
      ok = false;
    } else {
      /* Only touch the padding block once we know both objects were
         really allocated. */
      metadata[1]->length = 1234; /* set the padding length */

      ok &= FLAC__stream_encoder_set_metadata(encoder, metadata, 2);
    }
  }

  /* Initialize encoder, unless configuring it already failed. */

  if (ok) {
    FLAC__StreamEncoderInitStatus init_status =
      FLAC__stream_encoder_init_stream(encoder,
				       &write_callback,
				       &seek_callback,
				       &tell_callback,
				       &metadata_callback,
				       shard_info);
    if(init_status != FLAC__STREAM_ENCODER_INIT_STATUS_OK) {
      fprintf(stderr,
	      "ERROR: initializing encoder: %s\n",
	      FLAC__StreamEncoderInitStatusString[init_status]);
      ok = false;
    }
  }

  /* Must go after initialization for now.  Also, this is not
   compatible with verification. */
  if (ok) {
    ok =
      FLAC__stream_encoder_set_current_frame_number(encoder,
						    shard_info->
						    current_frame_number);
  }

  /* Do the encoding. (Also copied from example encoder.) */

  {
    FLAC__int32 pcm[INPUT_BUFSIZE * input_info->channels];
    size_t left = (size_t)shard_info->num_samples;
    FLAC__byte *bufin = shard_info->bufin;
    while(ok && left) {
      size_t need = (left>INPUT_BUFSIZE? (size_t)INPUT_BUFSIZE : (size_t)left);
      size_t need_bytes = need * input_info->sample_byte_size;
      /* convert the packed little-endian 16-bit PCM samples from WAVE into an interleaved FLAC__int32 buffer for libFLAC */
      size_t i;
      for(i = 0; i < need*input_info->channels; i++) {
	/* The high byte carries the sign.  Multiply instead of shifting
	   because left-shifting a negative value is undefined; the
	   result is the same and works on big- or little-endian
	   machines. */
	const FLAC__int32 hi = (FLAC__int8)bufin[2*i+1];
	const FLAC__int32 lo = bufin[2*i];
	pcm[i] = hi * 256 + lo;
      }
      /* feed samples to encoder */
      ok = FLAC__stream_encoder_process_interleaved(encoder, pcm, (unsigned)need);
      left -= need;
      bufin += need_bytes;
    }
  }

  ok &= FLAC__stream_encoder_finish(encoder);

  printf("encoding thread %d %s\n",
	 shard_info->thread_number,
	 ok ? "succeeded" : "FAILED");

  /* now that encoding is finished, the metadata can be freed; do so
     on failure too, but only for the objects that were allocated */
  if (metadata[0] != NULL) {
    FLAC__metadata_object_delete(metadata[0]);
  }
  if (metadata[1] != NULL) {
    FLAC__metadata_object_delete(metadata[1]);
  }

  FLAC__stream_encoder_delete(encoder);

  return ok;
}
