Restored dialog option, this time based on language model

This approach should be more robust and error-tolerant.
This commit is contained in:
Daniel Wolf 2016-06-03 21:07:49 +02:00
parent 4ed5908627
commit 0d488e8de2
11 changed files with 357 additions and 68 deletions

View File

@ -64,7 +64,7 @@ target_compile_options(cppFormat PRIVATE ${disableWarningsFlags})
set_target_properties(cppFormat PROPERTIES FOLDER lib) set_target_properties(cppFormat PROPERTIES FOLDER lib)
# ... sphinxbase # ... sphinxbase
include_directories(SYSTEM "lib/sphinxbase-5prealpha-2015-08-05/include") include_directories(SYSTEM "lib/sphinxbase-5prealpha-2015-08-05/include" "lib/sphinxbase-5prealpha-2015-08-05/src")
FILE(GLOB_RECURSE sphinxbaseFiles "lib/sphinxbase-5prealpha-2015-08-05/src/libsphinxbase/*.c") FILE(GLOB_RECURSE sphinxbaseFiles "lib/sphinxbase-5prealpha-2015-08-05/src/libsphinxbase/*.c")
add_library(sphinxbase ${sphinxbaseFiles}) add_library(sphinxbase ${sphinxbaseFiles})
target_compile_options(sphinxbase PRIVATE ${disableWarningsFlags}) target_compile_options(sphinxbase PRIVATE ${disableWarningsFlags})
@ -192,6 +192,8 @@ set(SOURCE_FILES
src/Exporter.cpp src/Exporter.h src/Exporter.cpp src/Exporter.h
src/tokenization.cpp src/tokenization.h src/tokenization.cpp src/tokenization.h
src/g2p.cpp src/g2p.h src/g2p.cpp src/g2p.h
src/languageModels.cpp src/languageModels.h
src/tupleHash.h
) )
add_executable(rhubarb ${SOURCE_FILES}) add_executable(rhubarb ${SOURCE_FILES})
target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite) target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite)

183
src/languageModels.cpp Normal file
View File

@ -0,0 +1,183 @@
#include "languageModels.h"
#include <boost/range/adaptor/map.hpp>
#include <vector>
#include <regex>
#include <map>
#include <tuple>
#include "platformTools.h"
#include <boost/filesystem/fstream.hpp>
#include "appInfo.h"
#include <cmath>
#include <gsl_util.h>
using std::string;
using std::u32string;
using std::vector;
using std::regex;
using std::map;
using std::tuple;
using std::make_tuple;
using std::get;
using std::endl;
using boost::filesystem::path;
using unigram_t = string;
using bigram_t = tuple<string, string>;
using trigram_t = tuple<string, string, string>;
map<unigram_t, int> getUnigramCounts(const vector<string>& words) {
map<unigram_t, int> unigramCounts;
for (const unigram_t& unigram : words) {
++unigramCounts[unigram];
}
return unigramCounts;
}
map<bigram_t, int> getBigramCounts(const vector<string>& words) {
map<bigram_t, int> bigramCounts;
for (auto it = words.begin(); it < words.end() - 1; ++it) {
++bigramCounts[bigram_t(*it, *(it + 1))];
}
return bigramCounts;
}
map<trigram_t, int> getTrigramCounts(const vector<string>& words) {
map<trigram_t, int> trigramCounts;
if (words.size() >= 3) {
for (auto it = words.begin(); it < words.end() - 2; ++it) {
++trigramCounts[trigram_t(*it, *(it + 1), *(it + 2))];
}
}
return trigramCounts;
}
map<unigram_t, double> getUnigramProbabilities(const vector<string>& words, const map<unigram_t, int>& unigramCounts, const double deflator) {
map<unigram_t, double> unigramProbabilities;
for (const auto& pair : unigramCounts) {
unigram_t unigram = get<0>(pair);
int unigramCount = get<1>(pair);
unigramProbabilities[unigram] = double(unigramCount) / words.size() * deflator;
}
return unigramProbabilities;
}
map<bigram_t, double> getBigramProbabilities(const map<unigram_t, int>& unigramCounts, const map<bigram_t, int>& bigramCounts, const double deflator) {
map<bigram_t, double> bigramProbabilities;
for (const auto& pair : bigramCounts) {
bigram_t bigram = get<0>(pair);
int bigramCount = get<1>(pair);
int unigramPrefixCount = unigramCounts.at(get<0>(bigram));
bigramProbabilities[bigram] = double(bigramCount) / unigramPrefixCount * deflator;
}
return bigramProbabilities;
}
map<trigram_t, double> getTrigramProbabilities(const map<bigram_t, int>& bigramCounts, const map<trigram_t, int>& trigramCounts, const double deflator) {
map<trigram_t, double> trigramProbabilities;
for (const auto& pair : trigramCounts) {
trigram_t trigram = get<0>(pair);
int trigramCount = get<1>(pair);
int bigramPrefixCount = bigramCounts.at(bigram_t(get<0>(trigram), get<1>(trigram)));
trigramProbabilities[trigram] = double(trigramCount) / bigramPrefixCount * deflator;
}
return trigramProbabilities;
}
map<unigram_t, double> getUnigramBackoffWeights(
const map<unigram_t, int>& unigramCounts,
const map<unigram_t, double>& unigramProbabilities,
const map<bigram_t, int>& bigramCounts,
const double discountMass)
{
map<unigram_t, double> unigramBackoffWeights;
for (const unigram_t& unigram : unigramCounts | boost::adaptors::map_keys) {
double denominator = 1;
for (const bigram_t& bigram : bigramCounts | boost::adaptors::map_keys) {
if (get<0>(bigram) == unigram) {
denominator -= unigramProbabilities.at(get<1>(bigram));
}
}
unigramBackoffWeights[unigram] = discountMass / denominator;
}
return unigramBackoffWeights;
}
map<bigram_t, double> getBigramBackoffWeights(
const map<bigram_t, int>& bigramCounts,
const map<bigram_t, double>& bigramProbabilities,
const map<trigram_t, int>& trigramCounts,
const double discountMass)
{
map<bigram_t, double> bigramBackoffWeights;
for (const bigram_t& bigram : bigramCounts | boost::adaptors::map_keys) {
double denominator = 1;
for (const trigram_t& trigram : trigramCounts | boost::adaptors::map_keys) {
if (bigram_t(get<0>(trigram), get<1>(trigram)) == bigram) {
denominator -= bigramProbabilities.at(bigram_t(get<1>(trigram), get<2>(trigram)));
}
}
bigramBackoffWeights[bigram] = discountMass / denominator;
}
return bigramBackoffWeights;
}
void createLanguageModelFile(const vector<string>& words, path filePath) {
const double discountMass = 0.5;
const double deflator = 1.0 - discountMass;
map<unigram_t, int> unigramCounts = getUnigramCounts(words);
map<bigram_t, int> bigramCounts = getBigramCounts(words);
map<trigram_t, int> trigramCounts = getTrigramCounts(words);
map<unigram_t, double> unigramProbabilities = getUnigramProbabilities(words, unigramCounts, deflator);
map<bigram_t, double> bigramProbabilities = getBigramProbabilities(unigramCounts, bigramCounts, deflator);
map<trigram_t, double> trigramProbabilities = getTrigramProbabilities(bigramCounts, trigramCounts, deflator);
map<unigram_t, double> unigramBackoffWeights = getUnigramBackoffWeights(unigramCounts, unigramProbabilities, bigramCounts, discountMass);
map<bigram_t, double> bigramBackoffWeights = getBigramBackoffWeights(bigramCounts, bigramProbabilities, trigramCounts, discountMass);
boost::filesystem::ofstream file(filePath);
file << "Generated by " << appName << " " << appVersion << endl << endl;
file << "\\data\\" << endl;
file << "ngram 1=" << unigramCounts.size() << endl;
file << "ngram 2=" << bigramCounts.size() << endl;
file << "ngram 3=" << trigramCounts.size() << endl << endl;
file.setf(std::ios::fixed, std::ios::floatfield);
file.precision(4);
file << "\\1-grams:" << endl;
for (const unigram_t& unigram : unigramCounts | boost::adaptors::map_keys) {
file << log10(unigramProbabilities.at(unigram))
<< " " << unigram
<< " " << log10(unigramBackoffWeights.at(unigram)) << endl;
}
file << endl;
file << "\\2-grams:" << endl;
for (const bigram_t& bigram : bigramCounts | boost::adaptors::map_keys) {
file << log10(bigramProbabilities.at(bigram))
<< " " << get<0>(bigram) << " " << get<1>(bigram)
<< " " << log10(bigramBackoffWeights.at(bigram)) << endl;
}
file << endl;
file << "\\3-grams:" << endl;
for (const trigram_t& trigram : trigramCounts | boost::adaptors::map_keys) {
file << log10(trigramProbabilities.at(trigram))
<< " " << get<0>(trigram) << " " << get<1>(trigram) << " " << get<2>(trigram) << endl;
}
file << endl;
file << "\\end\\" << endl;
}
lambda_unique_ptr<ngram_model_t> createLanguageModel(const vector<string>& words, logmath_t& logMath) {
path tempFilePath = getTempFilePath();
createLanguageModelFile(words, tempFilePath);
auto deleteTempFile = gsl::finally([&]() { boost::filesystem::remove(tempFilePath); });
return lambda_unique_ptr<ngram_model_t>(
ngram_model_read(nullptr, tempFilePath.string().c_str(), NGRAM_ARPA, &logMath),
[](ngram_model_t* lm) { ngram_model_free(lm); });
}

6
src/languageModels.h Normal file
View File

@ -0,0 +1,6 @@
#pragma once
#include <sphinxbase/ngram_model.h>
#include <vector>
#include "tools.h"
lambda_unique_ptr<ngram_model_t> createLanguageModel(const std::vector<std::string>& words, logmath_t& logMath);

View File

@ -12,9 +12,12 @@
#include <gsl_util.h> #include <gsl_util.h>
#include "Exporter.h" #include "Exporter.h"
#include "ContinuousTimeline.h" #include "ContinuousTimeline.h"
#include <boost/filesystem/operations.hpp>
#include "stringTools.h"
using std::exception; using std::exception;
using std::string; using std::string;
using std::u32string;
using std::vector; using std::vector;
using std::unique_ptr; using std::unique_ptr;
using std::make_unique; using std::make_unique;
@ -75,6 +78,25 @@ void addFileSink(path path, logging::Level minLevel) {
logging::addSink(levelFilter); logging::addSink(levelFilter);
} }
u32string readTextFile(path filePath) {
if (!exists(filePath)) {
throw std::invalid_argument(fmt::format("File {} does not exist.", filePath));
}
try {
boost::filesystem::ifstream file;
file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
file.open(filePath);
string utf8Text((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
try {
return utf8ToUtf32(utf8Text);
} catch (...) {
std::throw_with_nested(std::runtime_error(fmt::format("File encoding is not ASCII or UTF-8.", filePath)));
}
} catch (...) {
std::throw_with_nested(std::runtime_error(fmt::format("Error reading file {0}.", filePath)));
}
}
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
auto pausableStderrSink = addPausableStdErrSink(logging::Level::Warn); auto pausableStderrSink = addPausableStdErrSink(logging::Level::Warn);
pausableStderrSink->pause(); pausableStderrSink->pause();
@ -88,6 +110,7 @@ int main(int argc, char *argv[]) {
tclap::ValuesConstraint<logging::Level> logLevelConstraint(logLevels); tclap::ValuesConstraint<logging::Level> logLevelConstraint(logLevels);
tclap::ValueArg<logging::Level> logLevel("", "logLevel", "The minimum log level to log", false, logging::Level::Debug, &logLevelConstraint, cmd); tclap::ValueArg<logging::Level> logLevel("", "logLevel", "The minimum log level to log", false, logging::Level::Debug, &logLevelConstraint, cmd);
tclap::ValueArg<string> logFileName("", "logFile", "The log file path.", false, string(), "string", cmd); tclap::ValueArg<string> logFileName("", "logFile", "The log file path.", false, string(), "string", cmd);
tclap::ValueArg<string> dialogFile("d", "dialogFile", "A file containing the text of the dialog.", false, string(), "string", cmd);
auto exportFormats = vector<ExportFormat>(ExportFormatConverter::get().getValues()); auto exportFormats = vector<ExportFormat>(ExportFormatConverter::get().getValues());
tclap::ValuesConstraint<ExportFormat> exportFormatConstraint(exportFormats); tclap::ValuesConstraint<ExportFormat> exportFormatConstraint(exportFormats);
tclap::ValueArg<ExportFormat> exportFormat("f", "exportFormat", "The export format.", false, ExportFormat::TSV, &exportFormatConstraint, cmd); tclap::ValueArg<ExportFormat> exportFormat("f", "exportFormat", "The export format.", false, ExportFormat::TSV, &exportFormatConstraint, cmd);
@ -117,6 +140,7 @@ int main(int argc, char *argv[]) {
ProgressBar progressBar; ProgressBar progressBar;
phones = detectPhones( phones = detectPhones(
createAudioStream(inputFileName.getValue()), createAudioStream(inputFileName.getValue()),
dialogFile.isSet() ? readTextFile(path(dialogFile.getValue())) : boost::optional<u32string>(),
progressBar); progressBar);
} }
std::cerr << "Done" << std::endl; std::cerr << "Done" << std::endl;

View File

@ -1,6 +1,5 @@
#include <iostream> #include <iostream>
#include <boost/filesystem.hpp> #include <boost/filesystem.hpp>
#include <boost/algorithm/string.hpp>
#include "phoneExtraction.h" #include "phoneExtraction.h"
#include "audio/SampleRateConverter.h" #include "audio/SampleRateConverter.h"
#include "platformTools.h" #include "platformTools.h"
@ -14,6 +13,9 @@
#include <Timeline.h> #include <Timeline.h>
#include <audio/voiceActivityDetection.h> #include <audio/voiceActivityDetection.h>
#include <audio/AudioStreamSegment.h> #include <audio/AudioStreamSegment.h>
#include "languageModels.h"
#include "tokenization.h"
#include "g2p.h"
extern "C" { extern "C" {
#include <pocketsphinx.h> #include <pocketsphinx.h>
@ -35,33 +37,34 @@ using std::function;
using std::regex; using std::regex;
using std::regex_replace; using std::regex_replace;
using std::chrono::duration; using std::chrono::duration;
using boost::optional;
using std::u32string;
constexpr int sphinxSampleRate = 16000; constexpr int sphinxSampleRate = 16000;
lambda_unique_ptr<cmd_ln_t> createConfig(path sphinxModelDirectory) { const path& getSphinxModelDirectory() {
static path sphinxModelDirectory(getBinDirectory() / "res/sphinx");
return sphinxModelDirectory;
}
lambda_unique_ptr<ps_decoder_t> createDecoder() {
lambda_unique_ptr<cmd_ln_t> config( lambda_unique_ptr<cmd_ln_t> config(
cmd_ln_init( cmd_ln_init(
nullptr, ps_args(), true, nullptr, ps_args(), true,
// Set acoustic model // Set acoustic model
"-hmm", (sphinxModelDirectory / "acoustic-model").string().c_str(), "-hmm", (getSphinxModelDirectory() / "acoustic-model").string().c_str(),
// Set language model // Set pronunciation dictionary
"-lm", (sphinxModelDirectory / "en-us.lm.bin").string().c_str(), "-dict", (getSphinxModelDirectory() / "cmudict-en-us.dict").string().c_str(),
// Set pronounciation dictionary
"-dict", (sphinxModelDirectory / "cmudict-en-us.dict").string().c_str(),
// Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor) // Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
"-dither", "yes", "-dither", "yes",
nullptr), nullptr),
[](cmd_ln_t* config) { cmd_ln_free_r(config); }); [](cmd_ln_t* config) { cmd_ln_free_r(config); });
if (!config) throw runtime_error("Error creating configuration."); if (!config) throw runtime_error("Error creating configuration.");
return config;
}
lambda_unique_ptr<ps_decoder_t> createSpeechRecognizer(cmd_ln_t& config) {
lambda_unique_ptr<ps_decoder_t> recognizer( lambda_unique_ptr<ps_decoder_t> recognizer(
ps_init(&config), ps_init(config.get()),
[](ps_decoder_t* recognizer) { ps_free(recognizer); }); [](ps_decoder_t* recognizer) { ps_free(recognizer); });
if (!recognizer) throw runtime_error("Error creating speech recognizer."); if (!recognizer) throw runtime_error("Error creating speech decoder.");
return recognizer; return recognizer;
} }
@ -141,32 +144,32 @@ void sphinxLogCallback(void* user_data, err_lvl_t errorLevel, const char* format
logging::log(logLevel, message); logging::log(logLevel, message);
} }
BoundedTimeline<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) { BoundedTimeline<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& decoder, ProgressSink& progressSink) {
// Convert audio stream to the exact format PocketSphinx requires // Convert audio stream to the exact format PocketSphinx requires
audioStream = convertSampleRate(std::move(audioStream), sphinxSampleRate); audioStream = convertSampleRate(std::move(audioStream), sphinxSampleRate);
// Restart timing at 0 // Restart timing at 0
ps_start_stream(&recognizer); ps_start_stream(&decoder);
// Start recognition // Start recognition
int error = ps_start_utt(&recognizer); int error = ps_start_utt(&decoder);
if (error) throw runtime_error("Error starting utterance processing for word recognition."); if (error) throw runtime_error("Error starting utterance processing for word recognition.");
// Process entire sound file // Process entire sound file
auto processBuffer = [&recognizer](const vector<int16_t>& buffer) { auto processBuffer = [&decoder](const vector<int16_t>& buffer) {
int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false); int searchedFrameCount = ps_process_raw(&decoder, buffer.data(), buffer.size(), false, false);
if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition."); if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition.");
}; };
processAudioStream(*audioStream.get(), processBuffer, progressSink); processAudioStream(*audioStream.get(), processBuffer, progressSink);
// End recognition // End recognition
error = ps_end_utt(&recognizer); error = ps_end_utt(&decoder);
if (error) throw runtime_error("Error ending utterance processing for word recognition."); if (error) throw runtime_error("Error ending utterance processing for word recognition.");
// Collect words // Collect words
BoundedTimeline<string> result(audioStream->getTruncatedRange()); BoundedTimeline<string> result(audioStream->getTruncatedRange());
int32_t score; int32_t score;
for (ps_seg_t* it = ps_seg_iter(&recognizer, &score); it; it = ps_seg_next(it)) { for (ps_seg_t* it = ps_seg_iter(&decoder, &score); it; it = ps_seg_next(it)) {
const char* word = ps_seg_word(it); const char* word = ps_seg_word(it);
int firstFrame, lastFrame; int firstFrame, lastFrame;
ps_seg_frames(it, &firstFrame, &lastFrame); ps_seg_frames(it, &firstFrame, &lastFrame);
@ -176,35 +179,6 @@ BoundedTimeline<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_d
return result; return result;
} }
// Splits dialog into words, doing minimal preprocessing.
// A robust solution should use TTS logic to cope with numbers, abbreviations, unknown words etc.
vector<string> extractDialogWords(string dialog) {
// Convert to lower case
boost::algorithm::to_lower(dialog);
// Insert silences where appropriate
dialog = regex_replace(dialog, regex("[,;.:!?] |-"), " <sil> ");
// Remove all undesired characters
dialog = regex_replace(dialog, regex("[^a-z.'\\0-9<>]"), " ");
// Collapse whitespace
dialog = regex_replace(dialog, regex("\\s+"), " ");
// Trim
boost::algorithm::trim(dialog);
// Ugly hack: Remove trailing period
if (boost::algorithm::ends_with(dialog, ".")) {
dialog.pop_back();
}
// Split into words
vector<string> result;
boost::algorithm::split(result, dialog, boost::is_space());
return result;
}
s3wid_t getWordId(const string& word, dict_t& dictionary) { s3wid_t getWordId(const string& word, dict_t& dictionary) {
s3wid_t wordId = dict_wordid(&dictionary, word.c_str()); s3wid_t wordId = dict_wordid(&dictionary, word.c_str());
if (wordId == BAD_S3WID) throw invalid_argument(fmt::format("Unknown word '{}'.", word)); if (wordId == BAD_S3WID) throw invalid_argument(fmt::format("Unknown word '{}'.", word));
@ -214,12 +188,12 @@ s3wid_t getWordId(const string& word, dict_t& dictionary) {
BoundedTimeline<Phone> getPhoneAlignment( BoundedTimeline<Phone> getPhoneAlignment(
const vector<s3wid_t>& wordIds, const vector<s3wid_t>& wordIds,
unique_ptr<AudioStream> audioStream, unique_ptr<AudioStream> audioStream,
ps_decoder_t& recognizer, ps_decoder_t& decoder,
ProgressSink& progressSink) ProgressSink& progressSink)
{ {
// Create alignment list // Create alignment list
lambda_unique_ptr<ps_alignment_t> alignment( lambda_unique_ptr<ps_alignment_t> alignment(
ps_alignment_init(recognizer.d2p), ps_alignment_init(decoder.d2p),
[](ps_alignment_t* alignment) { ps_alignment_free(alignment); }); [](ps_alignment_t* alignment) { ps_alignment_free(alignment); });
if (!alignment) throw runtime_error("Error creating alignment."); if (!alignment) throw runtime_error("Error creating alignment.");
for (s3wid_t wordId : wordIds) { for (s3wid_t wordId : wordIds) {
@ -233,9 +207,9 @@ BoundedTimeline<Phone> getPhoneAlignment(
audioStream = convertSampleRate(std::move(audioStream), sphinxSampleRate); audioStream = convertSampleRate(std::move(audioStream), sphinxSampleRate);
// Create search structure // Create search structure
acmod_t* acousticModel = recognizer.acmod; acmod_t* acousticModel = decoder.acmod;
lambda_unique_ptr<ps_search_t> search( lambda_unique_ptr<ps_search_t> search(
state_align_search_init("state_align", recognizer.config, acousticModel, alignment.get()), state_align_search_init("state_align", decoder.config, acousticModel, alignment.get()),
[](ps_search_t* search) { ps_search_free(search); }); [](ps_search_t* search) { ps_search_free(search); });
if (!search) throw runtime_error("Error creating search."); if (!search) throw runtime_error("Error creating search.");
@ -247,7 +221,7 @@ BoundedTimeline<Phone> getPhoneAlignment(
ps_search_start(search.get()); ps_search_start(search.get());
// Process entire sound file // Process entire sound file
auto processBuffer = [&recognizer, &acousticModel, &search](const vector<int16_t>& buffer) { auto processBuffer = [&decoder, &acousticModel, &search](const vector<int16_t>& buffer) {
const int16* nextSample = buffer.data(); const int16* nextSample = buffer.data();
size_t remainingSamples = buffer.size(); size_t remainingSamples = buffer.size();
while (acmod_process_raw(acousticModel, &nextSample, &remainingSamples, false) > 0) { while (acmod_process_raw(acousticModel, &nextSample, &remainingSamples, false) > 0) {
@ -266,7 +240,7 @@ BoundedTimeline<Phone> getPhoneAlignment(
acmod_end_utt(acousticModel); acmod_end_utt(acousticModel);
// Extract phones with timestamps // Extract phones with timestamps
char** phoneNames = recognizer.dict->mdef->ciname; char** phoneNames = decoder.dict->mdef->ciname;
BoundedTimeline<Phone> result(audioStream->getTruncatedRange()); BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
for (ps_alignment_iter_t* it = ps_alignment_phones(alignment.get()); it; it = ps_alignment_iter_next(it)) { for (ps_alignment_iter_t* it = ps_alignment_phones(alignment.get()); it; it = ps_alignment_iter_next(it)) {
// Get phone // Get phone
@ -285,8 +259,28 @@ BoundedTimeline<Phone> getPhoneAlignment(
return result; return result;
} }
void addMissingDictionaryWords(const vector<string>& words, ps_decoder_t& decoder) {
map<string, string> missingPronunciations;
for (const string& word : words) {
if (dict_wordid(decoder.dict, word.c_str()) == BAD_S3WID) {
string pronunciation;
for (Phone phone : wordToPhones(word)) {
if (pronunciation.length() > 0) pronunciation += " ";
pronunciation += PhoneConverter::get().toString(phone);
}
missingPronunciations[word] = pronunciation;
}
}
for (auto it = missingPronunciations.begin(); it != missingPronunciations.end(); ++it) {
bool isLast = it == --missingPronunciations.end();
logging::infoFormat("Unknown word '{}'. Guessing pronunciation '{}'.", it->first, it->second);
ps_add_word(&decoder, it->first.c_str(), it->second.c_str(), isLast);
}
}
BoundedTimeline<Phone> detectPhones( BoundedTimeline<Phone> detectPhones(
unique_ptr<AudioStream> audioStream, unique_ptr<AudioStream> audioStream,
optional<u32string> dialog,
ProgressSink& progressSink) ProgressSink& progressSink)
{ {
// Pocketsphinx doesn't like empty input // Pocketsphinx doesn't like empty input
@ -305,13 +299,6 @@ BoundedTimeline<Phone> detectPhones(
audioStream = removeDCOffset(std::move(audioStream)); audioStream = removeDCOffset(std::move(audioStream));
try { try {
// Create PocketSphinx configuration
path sphinxModelDirectory(getBinDirectory() / "res/sphinx");
auto config = createConfig(sphinxModelDirectory);
// Create speech recognizer
auto recognizer = createSpeechRecognizer(*config.get());
// Split audio into utterances // Split audio into utterances
BoundedTimeline<void> utterances = detectVoiceActivity(audioStream->clone(true)); BoundedTimeline<void> utterances = detectVoiceActivity(audioStream->clone(true));
@ -323,6 +310,29 @@ BoundedTimeline<Phone> detectPhones(
} }
auto utteranceProgressSinkIt = utteranceProgressSinks.begin(); auto utteranceProgressSinkIt = utteranceProgressSinks.begin();
// Create speech recognizer
auto decoder = createDecoder();
// Set language model
lambda_unique_ptr<ngram_model_t> languageModel;
if (dialog) {
// Create dialog-specific language model
vector<string> words = tokenizeText(*dialog);
words.insert(words.begin(), "<s>");
words.push_back("</s>");
languageModel = createLanguageModel(words, *decoder->lmath);
// Add any dialog-specific words to the dictionary
addMissingDictionaryWords(words, *decoder);
} else {
path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
languageModel = lambda_unique_ptr<ngram_model_t>(
ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
[](ngram_model_t* lm) { ngram_model_free(lm); });
}
ps_set_lm(decoder.get(), "lm", languageModel.get());
ps_set_search(decoder.get(), "lm");
BoundedTimeline<Phone> result(audioStream->getTruncatedRange()); BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
for (const auto& timedUtterance : utterances) { for (const auto& timedUtterance : utterances) {
ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++); ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
@ -335,7 +345,7 @@ BoundedTimeline<Phone> detectPhones(
auto streamSegment = createSegment(audioStream->clone(true), timeRange); auto streamSegment = createSegment(audioStream->clone(true), timeRange);
// Get words // Get words
BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *recognizer.get(), wordRecognitionProgressSink); BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *decoder.get(), wordRecognitionProgressSink);
for (Timed<string> timedWord : words) { for (Timed<string> timedWord : words) {
timedWord.getTimeRange().shift(timedUtterance.getStart()); timedWord.getTimeRange().shift(timedUtterance.getStart());
logging::logTimedEvent("word", timedWord); logging::logTimedEvent("word", timedWord);
@ -344,12 +354,12 @@ BoundedTimeline<Phone> detectPhones(
// Look up words in dictionary // Look up words in dictionary
vector<s3wid_t> wordIds; vector<s3wid_t> wordIds;
for (const auto& timedWord : words) { for (const auto& timedWord : words) {
wordIds.push_back(getWordId(timedWord.getValue(), *recognizer->dict)); wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
} }
if (wordIds.empty()) continue; if (wordIds.empty()) continue;
// Align the words' phones with speech // Align the words' phones with speech
BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *recognizer.get(), alignmentProgressSink); BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink);
segmentPhones.shift(timedUtterance.getStart()); segmentPhones.shift(timedUtterance.getStart());
for (const auto& timedPhone : segmentPhones) { for (const auto& timedPhone : segmentPhones) {
logging::logTimedEvent("phone", timedPhone); logging::logTimedEvent("phone", timedPhone);

View File

@ -8,4 +8,5 @@
BoundedTimeline<Phone> detectPhones( BoundedTimeline<Phone> detectPhones(
std::unique_ptr<AudioStream> audioStream, std::unique_ptr<AudioStream> audioStream,
boost::optional<std::u32string> dialog,
ProgressSink& progressSink); ProgressSink& progressSink);

View File

@ -2,10 +2,14 @@
#include <boost/filesystem/path.hpp> #include <boost/filesystem/path.hpp>
#include <boost/predef.h> #include <boost/predef.h>
#include <format.h> #include <format.h>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
#include "platformTools.h" #include "platformTools.h"
using boost::filesystem::path; using boost::filesystem::path;
using std::string;
constexpr int InitialBufferSize = 256; constexpr int InitialBufferSize = 256;
@ -129,3 +133,10 @@ path getBinPath() {
path getBinDirectory() { path getBinDirectory() {
return getBinPath().parent_path(); return getBinPath().parent_path();
} }
path getTempFilePath() {
path tempDirectory = boost::filesystem::temp_directory_path();
static auto generateUuid = boost::uuids::random_generator();
string fileName = to_string(generateUuid());
return tempDirectory / fileName;
}

View File

@ -4,3 +4,4 @@
boost::filesystem::path getBinPath(); boost::filesystem::path getBinPath();
boost::filesystem::path getBinDirectory(); boost::filesystem::path getBinDirectory();
boost::filesystem::path getTempFilePath();

View File

@ -1,5 +1,6 @@
#include "stringTools.h" #include "stringTools.h"
#include <boost/algorithm/string/trim.hpp> #include <boost/algorithm/string/trim.hpp>
#include <codecvt>
using std::string; using std::string;
using std::wstring; using std::wstring;
@ -106,3 +107,12 @@ string toASCII(const u32string& s) {
} }
return result; return result;
} }
u32string utf8ToUtf32(const string& s) {
// Visual Studio 2015 has a bug regarding char32_t:
// https://connect.microsoft.com/VisualStudio/feedback/details/1403302/unresolved-external-when-using-codecvt-utf8
// Once VS2016 is out, we can use char32_t instead of uint32_t as type arguments and get rid of the outer conversion.
std::wstring_convert<std::codecvt_utf8<uint32_t>, uint32_t> convert;
return u32string(reinterpret_cast<const char32_t*>(convert.from_bytes(s).c_str()));
}

View File

@ -1,6 +1,5 @@
#pragma once #pragma once
#include <string>
#include <vector> #include <vector>
#include <boost/optional.hpp> #include <boost/optional.hpp>
@ -14,4 +13,6 @@ std::wstring latin1ToWide(const std::string& s);
boost::optional<char> toASCII(char32_t ch); boost::optional<char> toASCII(char32_t ch);
std::string toASCII(const std::u32string& s); std::string toASCII(const std::u32string& s);
std::u32string utf8ToUtf32(const std::string& s);

40
src/tupleHash.h Normal file
View File

@ -0,0 +1,40 @@
#pragma once
#include <tuple>
namespace std {
namespace {
template <typename T>
void hash_combine(size_t& seed, const T& value) {
seed ^= std::hash<T>()(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
// Recursive template code derived from Matthieu M.
template <typename Tuple, size_t Index = tuple_size<Tuple>::value - 1>
struct HashValueImpl {
static void apply(size_t& seed, const Tuple& tuple) {
HashValueImpl<Tuple, Index - 1>::apply(seed, tuple);
hash_combine(seed, std::get<Index>(tuple));
}
};
template <typename Tuple>
struct HashValueImpl<Tuple, 0> {
static void apply(size_t& seed, const Tuple& tuple) {
hash_combine(seed, std::get<0>(tuple));
}
};
}
template <typename ... TT>
struct hash<tuple<TT...>> {
size_t operator()(const tuple<TT...>& tt) const {
size_t seed = 0;
HashValueImpl<tuple<TT...> >::apply(seed, tt);
return seed;
}
};
}