// Patch summary for src/phoneRecognition.cpp: moves language-model setup out of createDecoder
// into three helper functions and changes which model is used when a dialog is supplied.
//
// Hunk 1: adds `using std::array;`.
// Hunk 2: adds three helpers.
//   - createDefaultLanguageModel(decoder): reads "en-us.lm.bin" from getSphinxModelDirectory()
//     with ngram_model_read and throws runtime_error if the returned pointer is null.
//   - createDialogLanguageModel(decoder, dialog): tokenizes the dialog (dictionaryContains on
//     decoder.dict is the predicate passed to tokenizeText), calls addMissingDictionaryWords,
//     then inserts one entry at the front and appends one at the back of the word list before
//     handing it to createLanguageModel.
//   - createBiasedLanguageModel(decoder, dialog): builds both models above and combines them
//     with ngram_model_set_init, weighting the default model 0.1 and the dialog model 0.9;
//     throws runtime_error if the returned pointer is null.
// Hunk 3: createDecoder now calls createBiasedLanguageModel when `dialog` is set and
//   createDefaultLanguageModel otherwise, replacing the inline if/else.
//
// Behavior changes visible in this patch:
//   - With a dialog, the removed code used only the model from createLanguageModel; the new code
//     uses the weighted set of default + dialog models.
//   - addMissingDictionaryWords now runs before the front/back entries are inserted and before
//     createLanguageModel; in the removed code it ran after both.
//   - A null result from ngram_model_read now throws; the removed code did not check it.
//
// NOTE(review): both `return std::move(result);` statements return a local by value; a plain
//   `return result;` should be sufficient and does not inhibit copy elision - confirm that
//   lambda_unique_ptr is move-constructible as expected before changing.
// NOTE(review): defaultLanguageModel and dialogLanguageModel go out of scope when
//   createBiasedLanguageModel returns - this presumably relies on ngram_model_set_init retaining
//   the sub-models; TODO confirm against the sphinxbase sources.
// NOTE(review): template argument lists (e.g. after lambda_unique_ptr, vector, array, optional)
//   and the two string literals passed to words.insert / words.push_back are empty in this copy
//   of the patch; this looks like stripped angle-bracket text - TODO confirm against the
//   repository before applying.
diff --git a/src/phoneRecognition.cpp b/src/phoneRecognition.cpp index b2c5e88..6931c76 100644 --- a/src/phoneRecognition.cpp +++ b/src/phoneRecognition.cpp @@ -45,6 +45,7 @@ using std::chrono::duration; using boost::optional; using std::u32string; using std::chrono::duration_cast; +using std::array; constexpr int sphinxSampleRate = 16000; @@ -237,6 +238,48 @@ void addMissingDictionaryWords(const vector& words, ps_decoder_t& decode } } +lambda_unique_ptr createDefaultLanguageModel(ps_decoder_t& decoder) { + path modelPath = getSphinxModelDirectory() / "en-us.lm.bin"; + lambda_unique_ptr result( + ngram_model_read(decoder.config, modelPath.string().c_str(), NGRAM_AUTO, decoder.lmath), + [](ngram_model_t* lm) { ngram_model_free(lm); }); + if (!result) { + throw runtime_error(fmt::format("Error reading language model from {}.", modelPath)); + } + + return std::move(result); +} + +lambda_unique_ptr createDialogLanguageModel(ps_decoder_t& decoder, const u32string& dialog) { + // Split dialog into normalized words + vector words = tokenizeText(dialog, [&](const string& word) { return dictionaryContains(*decoder.dict, word); }); + + // Add dialog-specific words to the dictionary + addMissingDictionaryWords(words, decoder); + + // Create dialog-specific language model + words.insert(words.begin(), ""); + words.push_back(""); + return createLanguageModel(words, decoder); +} + +lambda_unique_ptr createBiasedLanguageModel(ps_decoder_t& decoder, const u32string& dialog) { + auto defaultLanguageModel = createDefaultLanguageModel(decoder); + auto dialogLanguageModel = createDialogLanguageModel(decoder, dialog); + constexpr int modelCount = 2; + array languageModels{ defaultLanguageModel.get(), dialogLanguageModel.get() }; + array modelNames{ "defaultLM", "dialogLM" }; + array modelWeights{ 0.1f, 0.9f }; + lambda_unique_ptr result( + ngram_model_set_init(nullptr, languageModels.data(), modelNames.data(), modelWeights.data(), modelCount), + [](ngram_model_t* lm) { 
ngram_model_free(lm); }); + if (!result) { + throw runtime_error("Error creating biased language model."); + } + + return std::move(result); +} + lambda_unique_ptr createDecoder(optional dialog) { lambda_unique_ptr config( cmd_ln_init( @@ -261,22 +304,9 @@ lambda_unique_ptr createDecoder(optional dialog) { if (!decoder) throw runtime_error("Error creating speech decoder."); // Set language model - lambda_unique_ptr languageModel; - if (dialog) { - // Create dialog-specific language model - vector words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); }); - words.insert(words.begin(), ""); - words.push_back(""); - languageModel = createLanguageModel(words, *decoder); - - // Add any dialog-specific words to the dictionary - addMissingDictionaryWords(words, *decoder); - } else { - path modelPath = getSphinxModelDirectory() / "en-us.lm.bin"; - languageModel = lambda_unique_ptr( - ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath), - [](ngram_model_t* lm) { ngram_model_free(lm); }); - } + lambda_unique_ptr languageModel(dialog + ? createBiasedLanguageModel(*decoder, *dialog) + : createDefaultLanguageModel(*decoder)); ps_set_lm(decoder.get(), "lm", languageModel.get()); ps_set_search(decoder.get(), "lm");