Skip to content

Commit

Permalink
Merge pull request #2352 from adeepH/logits
Browse files Browse the repository at this point in the history
Getting logits to the client code
  • Loading branch information
reuben committed Mar 24, 2023
2 parents 946deb0 + 4deff6b commit e7d2af9
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 9 deletions.
8 changes: 7 additions & 1 deletion native_client/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ bool init_from_array_of_bytes = false;
int json_candidate_transcripts = 3;

int stream_size = 0;

bool keep_emissions = false;
int extended_stream_size = 0;

char* hot_words = NULL;
Expand All @@ -59,6 +59,7 @@ void PrintHelp(const char* bin)
"\t--lm_beta LM_BETA\t\tValue for language model beta param (float)\n"
"\t-t\t\t\t\tRun in benchmark mode, output mfcc & inference time\n"
"\t--extended\t\t\tOutput string from extended metadata\n"
"\t--keep_emissions\t\t\tSave the output of the acoustic model\n"
"\t--json\t\t\t\tExtended output, shows word timings as JSON\n"
"\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n"
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
Expand All @@ -85,6 +86,7 @@ bool ProcessArgs(int argc, char** argv)
{"lm_beta", required_argument, nullptr, 'd'},
{"t", no_argument, nullptr, 't'},
{"extended", no_argument, nullptr, 'e'},
{"keep_emissions", no_argument, nullptr, 'L'},
{"json", no_argument, nullptr, 'j'},
{"init_from_bytes", no_argument, nullptr, 'B'},
{"candidate_transcripts", required_argument, nullptr, 150},
Expand Down Expand Up @@ -140,6 +142,10 @@ bool ProcessArgs(int argc, char** argv)
extended_metadata = true;
break;

case 'L':
keep_emissions = true;
break;

case 'j':
json_output = true;
break;
Expand Down
38 changes: 36 additions & 2 deletions native_client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,36 @@ MetadataToJSON(Metadata* result)
}
}

if (keep_emissions && result->emissions != NULL) {
int num_timesteps = result->emissions->num_timesteps;
int num_symbols = result->emissions->num_symbols;
int class_dim = num_symbols + 1;
const char **symbol_table = result->emissions->symbols;
out_string << ",\n" << R"("alphabet")" << ":[";
for(int i = 0; i < class_dim; i++) {
out_string << "\"" << symbol_table[i] << "\"";
if(i < class_dim - 1) {
out_string << ", ";
}
}
out_string << "],\n" << R"("emissions")" << ":[\n";
for(int i = 0; i < num_timesteps; i++) {
out_string << "[";
for(int j = 0; j < num_symbols; j++) {
out_string << result->emissions->emissions[i * num_symbols + j];
if(j < num_symbols - 1) {
out_string << ", ";
}
}
out_string << "]";
if(i < num_timesteps - 1) {
out_string << ",";
}
out_string << "\n";
}
out_string << "\n]";
}

out_string << "\n}\n";

return strdup(out_string.str().c_str());
Expand All @@ -169,14 +199,18 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
clock_t stt_start_time = clock();

// sphinx-doc: c_ref_inference_start
if (extended_output) {
if (extended_output && !keep_emissions) {
Metadata *result = STT_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, 1);
res.string = CandidateTranscriptToString(&result->transcripts[0]);
STT_FreeMetadata(result);
} else if (json_output) {
} else if (json_output && !keep_emissions) {
Metadata *result = STT_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, json_candidate_transcripts);
res.string = MetadataToJSON(result);
STT_FreeMetadata(result);
} else if (keep_emissions) {
Metadata *result = STT_SpeechToTextWithEmissions(aCtx, aBuffer, aBufferSize, json_candidate_transcripts);
res.string = MetadataToJSON(result);
STT_FreeMetadata(result);
} else if (stream_size > 0) {
StreamingState* ctx;
int status = STT_CreateStream(aCtx, &ctx);
Expand Down
43 changes: 41 additions & 2 deletions native_client/coqui-stt.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,28 @@ typedef struct CandidateTranscript {
* contributed to the creation of this transcript.
*/
const double confidence;

} CandidateTranscript;

/**
* @brief An structure to contain emissions (the softmax output of individual
* timesteps) from the acoustic model.
*
* @member The layout of the emissions member is time major, thus to access the
* probability of symbol j at timestep i you would use
* emissions[i * num_symbols + j]
*/
typedef struct AcousticModelEmissions {
/** number of symbols in the alphabet, including CTC blank */
int num_symbols;
/** num_symbols long array of NUL-terminated strings */
const char **symbols;
/** total number of timesteps */
int num_timesteps;
/** num_timesteps long array, each pointer is a num_symbols long array */
const double *emissions;
} AcousticModelEmissions;

/**
* @brief An array of CandidateTranscript objects computed by the model.
*/
Expand All @@ -61,6 +81,8 @@ typedef struct Metadata {
const CandidateTranscript* const transcripts;
/** Size of the transcripts array */
const unsigned int num_transcripts;
/** Logits and information to decode them **/
const AcousticModelEmissions* const emissions;
} Metadata;

#endif /* SWIG_ERRORS_ONLY */
Expand Down Expand Up @@ -296,14 +318,31 @@ Metadata* STT_SpeechToTextWithMetadata(ModelState* aCtx,
unsigned int aNumResults);

/**
* @brief Create a new streaming inference state. The streaming state returned
* by this function can then be passed to {@link STT_FeedAudioContent()}
* @brief Use the Coqui STT model to generate emissions (the softmax output of individual
* timesteps).
* by this function can then be passed to {@link STT_CreateStream()}
* and {@link STT_FinishStream()}.
*
* @param aCtx The ModelState pointer for the model to use.
* @param[out] retval an opaque pointer that represents the streaming state. Can
* be NULL if an error occurs.
*
* @return probability of symbol j at timestep i you would use
* emissions[i * num_symbols + j]
*/
STT_EXPORT
Metadata* STT_SpeechToTextWithEmissions(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aNumResults);

/**
* @brief Create a new streaming inference state. The streaming state returned
* by this function can then be passed to {@link STT_FeedAudioContent()}
* and {@link STT_FinishStream()}.
*
* @param aCtx The ModelState pointer for the model to use.
*
* @return Zero for success, non-zero on failure.
*/
STT_EXPORT
Expand Down
1 change: 1 addition & 0 deletions native_client/ctcdecode/output.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ struct Output {
double confidence;
std::vector<unsigned int> tokens;
std::vector<unsigned int> timesteps;
std::vector<std::vector<std::pair<int, double>>> probs;
};

struct FlashlightOutput {
Expand Down
1 change: 1 addition & 0 deletions native_client/modelstate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ ModelState::decode_metadata(const DecoderState& state,
Metadata metadata {
transcripts, // transcripts
num_returned, // num_transcripts
NULL,
};
memcpy(ret, &metadata, sizeof(Metadata));
return ret;
Expand Down
147 changes: 143 additions & 4 deletions native_client/stt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ struct StreamingState {
vector<float> batch_buffer_;
vector<float> previous_state_c_;
vector<float> previous_state_h_;
bool keep_emissions_ = false;

vector<double> probs_;
ModelState* model_;
DecoderState decoder_state_;

Expand Down Expand Up @@ -134,7 +136,42 @@ StreamingState::intermediateDecode() const
Metadata*
StreamingState::intermediateDecodeWithMetadata(unsigned int num_results) const
{
return model_->decode_metadata(decoder_state_, num_results);
Metadata *m = model_->decode_metadata(decoder_state_, num_results);

if (keep_emissions_) {

const size_t alphabet_size = model_->alphabet_.GetSize();
const int num_timesteps = probs_.size() / (ModelState::BATCH_SIZE * (alphabet_size + 1));

AcousticModelEmissions* emissions = (AcousticModelEmissions*)malloc(sizeof(AcousticModelEmissions));

emissions->num_symbols = alphabet_size;
emissions->num_timesteps = num_timesteps;
emissions->symbols = (const char**)malloc(sizeof(char*)*alphabet_size + 1);
for (int i = 0; i < alphabet_size; i++) {
emissions->symbols[i] = strdup(model_->alphabet_.DecodeSingle(i).c_str());
}
emissions->symbols[alphabet_size] = strdup("\t");

double* probs = (double*)malloc(sizeof(double)*(alphabet_size + 1)*num_timesteps);
memcpy(probs, probs_.data(), sizeof(double)*(alphabet_size + 1)*num_timesteps);

emissions->emissions = probs;

Metadata* ret = (Metadata*)malloc(sizeof(Metadata));

Metadata metadata {
m->transcripts, // transcripts
m->num_transcripts, // num_transcripts
emissions,
};

memcpy(ret, &metadata, sizeof(Metadata));

return ret;
}

return m;
}

char*
Expand All @@ -148,7 +185,42 @@ Metadata*
StreamingState::finishStreamWithMetadata(unsigned int num_results)
{
flushBuffers(true);
return model_->decode_metadata(decoder_state_, num_results);
Metadata *m = model_->decode_metadata(decoder_state_, num_results);

if (keep_emissions_) {

const size_t alphabet_size = model_->alphabet_.GetSize();
const int num_timesteps = probs_.size() / (ModelState::BATCH_SIZE * (alphabet_size + 1));

AcousticModelEmissions* emissions = (AcousticModelEmissions*)malloc(sizeof(AcousticModelEmissions));

emissions->num_symbols = alphabet_size;
emissions->num_timesteps = num_timesteps;
emissions->symbols = (const char**)malloc(sizeof(char*)*alphabet_size + 1);
for (int i = 0; i < alphabet_size; i++) {
emissions->symbols[i] = strdup(model_->alphabet_.DecodeSingle(i).c_str());
}
emissions->symbols[alphabet_size] = strdup("\t");

double* probs = (double*)malloc(sizeof(double)*(alphabet_size + 1)*num_timesteps);
memcpy(probs, probs_.data(), sizeof(double)*(alphabet_size + 1)*num_timesteps);

emissions->emissions = probs;

Metadata* ret = (Metadata*)malloc(sizeof(Metadata));

Metadata metadata {
m->transcripts, // transcripts
m->num_transcripts, // num_transcripts
emissions,
};

memcpy(ret, &metadata, sizeof(Metadata));

return ret;
}

return m;
}

void
Expand Down Expand Up @@ -253,7 +325,9 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)

// Convert logits to double
vector<double> inputs(logits.begin(), logits.end());

if (keep_emissions_) {
probs_ = inputs;
}
decoder_state_.next(inputs.data(),
n_frames,
num_classes);
Expand Down Expand Up @@ -476,6 +550,41 @@ STT_CreateStream(ModelState* aCtx,
return STT_ERR_OK;
}

int
CreateStreamWithEmissions(ModelState* aCtx,
StreamingState** retval)
{
*retval = nullptr;

std::unique_ptr<StreamingState> ctx(new StreamingState());
if (!ctx) {
std::cerr << "Could not allocate streaming state." << std::endl;
return STT_ERR_FAIL_CREATE_STREAM;
}

ctx->audio_buffer_.reserve(aCtx->audio_win_len_);
ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
ctx->previous_state_c_.resize(aCtx->state_size_, 0.f);
ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
ctx->model_ = aCtx;
ctx->keep_emissions_ = true;

const int cutoff_top_n = 40;
const double cutoff_prob = 1.0;

ctx->decoder_state_.init(aCtx->alphabet_,
aCtx->beam_width_,
cutoff_prob,
cutoff_top_n,
aCtx->scorer_,
aCtx->hot_words_);

*retval = ctx.release();
return STT_ERR_OK;
}

void
STT_FeedAudioContent(StreamingState* aSctx,
const short* aBuffer,
Expand Down Expand Up @@ -562,6 +671,22 @@ STT_SpeechToTextWithMetadata(ModelState* aCtx,
return STT_FinishStreamWithMetadata(ctx, aNumResults);
}

Metadata*
STT_SpeechToTextWithEmissions(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aNumResults)
{
StreamingState* ctx;
int status = CreateStreamWithEmissions(aCtx, &ctx);
if (status != STT_ERR_OK) {
return nullptr;
}
STT_FeedAudioContent(ctx, aBuffer, aBufferSize);

return STT_FinishStreamWithMetadata(ctx, aNumResults);
}

void
STT_FreeStream(StreamingState* aSctx)
{
Expand All @@ -581,10 +706,24 @@ STT_FreeMetadata(Metadata* m)
}

free((void*)m->transcripts);

// Clean up logits if they are not NULL
if (m->emissions) {

if (m->emissions->symbols) {
for (int i = 0; i < m->emissions->num_symbols + 1; i++) {
free((void*)m->emissions->symbols[i]);
}
free((void*)m->emissions->symbols);
}
if (m->emissions->emissions) {
free((void*)m->emissions->emissions);
}
free((void*)m->emissions);
}
free(m);
}
}

void
STT_FreeString(char* str)
{
Expand Down

0 comments on commit e7d2af9

Please sign in to comment.