Improved tokenization by taking dictionary into account

This commit is contained in:
Daniel Wolf 2016-06-25 21:52:04 +02:00
parent 8502256241
commit c9b17e1937
4 changed files with 60 additions and 20 deletions

View File

@ -231,10 +231,14 @@ optional<BoundedTimeline<Phone>> getPhoneAlignment(
return result; return result;
} }
bool dictionaryContains(dict_t& dictionary, const string& word) {
return dict_wordid(&dictionary, word.c_str()) != BAD_S3WID;
}
void addMissingDictionaryWords(const vector<string>& words, ps_decoder_t& decoder) { void addMissingDictionaryWords(const vector<string>& words, ps_decoder_t& decoder) {
map<string, string> missingPronunciations; map<string, string> missingPronunciations;
for (const string& word : words) { for (const string& word : words) {
if (dict_wordid(decoder.dict, word.c_str()) == BAD_S3WID) { if (!dictionaryContains(*decoder.dict, word)) {
string pronunciation; string pronunciation;
for (Phone phone : wordToPhones(word)) { for (Phone phone : wordToPhones(word)) {
if (pronunciation.length() > 0) pronunciation += " "; if (pronunciation.length() > 0) pronunciation += " ";
@ -287,7 +291,7 @@ BoundedTimeline<Phone> detectPhones(
lambda_unique_ptr<ngram_model_t> languageModel; lambda_unique_ptr<ngram_model_t> languageModel;
if (dialog) { if (dialog) {
// Create dialog-specific language model // Create dialog-specific language model
vector<string> words = tokenizeText(*dialog); vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
words.insert(words.begin(), "<s>"); words.insert(words.begin(), "<s>");
words.push_back("</s>"); words.push_back("</s>");
languageModel = createLanguageModel(words, *decoder->lmath); languageModel = createLanguageModel(words, *decoder->lmath);

View File

@ -15,6 +15,8 @@ using std::string;
using std::vector; using std::vector;
using std::regex; using std::regex;
using std::pair; using std::pair;
using boost::optional;
using std::function;
lambda_unique_ptr<cst_voice> createDummyVoice() { lambda_unique_ptr<cst_voice> createDummyVoice() {
lambda_unique_ptr<cst_voice> voice(new_voice(), [](cst_voice* voice) { delete_voice(voice); }); lambda_unique_ptr<cst_voice> voice(new_voice(), [](cst_voice* voice) { delete_voice(voice); });
@ -51,7 +53,27 @@ vector<string> tokenizeViaFlite(const string& text) {
return result; return result;
} }
vector<string> tokenizeText(const u32string& text) { optional<string> findSimilarDictionaryWord(const string& word, function<bool(const string&)> dictionaryContains) {
for (bool addPeriod : { false, true }) {
for (int apostropheIndex = -1; apostropheIndex <= static_cast<int>(word.size()); ++apostropheIndex) {
string modified = word;
if (apostropheIndex != -1) {
modified.insert(apostropheIndex, "'");
}
if (addPeriod) {
modified += ".";
}
if (dictionaryContains(modified)) {
return modified;
}
}
}
return boost::none;
}
vector<string> tokenizeText(const u32string& text, function<bool(const string&)> dictionaryContains) {
vector<string> words = tokenizeViaFlite(toASCII(text)); vector<string> words = tokenizeViaFlite(toASCII(text));
// Join words separated by apostophes // Join words separated by apostophes
@ -63,7 +85,7 @@ vector<string> tokenizeText(const u32string& text) {
} }
// Turn some symbols into words, remove the rest // Turn some symbols into words, remove the rest
vector<pair<regex, string>> replacements { const static vector<pair<regex, string>> replacements {
{ regex("&"), "and" }, { regex("&"), "and" },
{ regex("\\*"), "times" }, { regex("\\*"), "times" },
{ regex("\\+"), "plus" }, { regex("\\+"), "plus" },
@ -73,12 +95,22 @@ vector<string> tokenizeText(const u32string& text) {
}; };
for (size_t i = 0; i < words.size(); ++i) { for (size_t i = 0; i < words.size(); ++i) {
for (const auto& replacement : replacements) { for (const auto& replacement : replacements) {
words[i] = std::regex_replace(words[i], replacement.first, replacement.second); words[i] = regex_replace(words[i], replacement.first, replacement.second);
} }
} }
// Remove empty words // Remove empty words
words.erase(std::remove_if(words.begin(), words.end(), [](const string& s) { return s.empty(); }), words.end()); words.erase(std::remove_if(words.begin(), words.end(), [](const string& s) { return s.empty(); }), words.end());
// Try to replace words that are not in the dictionary with similar ones that are
for (size_t i = 0; i < words.size(); ++i) {
if (!dictionaryContains(words[i])) {
optional<string> modifiedWord = findSimilarDictionaryWord(words[i], dictionaryContains);
if (modifiedWord) {
words[i] = *modifiedWord;
}
}
}
return words; return words;
} }

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include <string> #include <functional>
std::vector<std::string> tokenizeText(const std::u32string& text); std::vector<std::string> tokenizeText(const std::u32string& text, std::function<bool(const std::string&)> dictionaryContains);

View File

@ -1,6 +1,7 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include "tokenization.h" #include "tokenization.h"
#include <regex> #include <regex>
#include <unordered_set>
using namespace testing; using namespace testing;
using std::string; using std::string;
@ -8,48 +9,51 @@ using std::u32string;
using std::vector; using std::vector;
using std::regex; using std::regex;
bool returnTrue(const string&) {
return true;
}
TEST(tokenizeText, simpleCases) { TEST(tokenizeText, simpleCases) {
EXPECT_THAT(tokenizeText(U""), IsEmpty()); EXPECT_THAT(tokenizeText(U"", returnTrue), IsEmpty());
EXPECT_THAT(tokenizeText(U" \t\n\r\n "), IsEmpty()); EXPECT_THAT(tokenizeText(U" \t\n\r\n ", returnTrue), IsEmpty());
EXPECT_THAT( EXPECT_THAT(
tokenizeText(U"Wit is educated insolence."), tokenizeText(U"Wit is educated insolence.", returnTrue),
ElementsAre("wit", "is", "educated", "insolence") ElementsAre("wit", "is", "educated", "insolence")
); );
} }
TEST(tokenizeText, numbers) { TEST(tokenizeText, numbers) {
EXPECT_THAT( EXPECT_THAT(
tokenizeText(U"Henry V died at 36."), tokenizeText(U"Henry V died at 36.", returnTrue),
ElementsAre("henry", "the", "fifth", "died", "at", "thirty", "six") ElementsAre("henry", "the", "fifth", "died", "at", "thirty", "six")
); );
EXPECT_THAT( EXPECT_THAT(
tokenizeText(U"I spent $4.50 on gum."), tokenizeText(U"I spent $4.50 on gum.", returnTrue),
ElementsAre("i", "spent", "four", "dollars", "fifty", "cents", "on", "gum") ElementsAre("i", "spent", "four", "dollars", "fifty", "cents", "on", "gum")
); );
EXPECT_THAT( EXPECT_THAT(
tokenizeText(U"I was born in 1982."), tokenizeText(U"I was born in 1982.", returnTrue),
ElementsAre("i", "was", "born", "in", "nineteen", "eighty", "two") ElementsAre("i", "was", "born", "in", "nineteen", "eighty", "two")
); );
} }
TEST(tokenizeText, abbreviations) { TEST(tokenizeText, abbreviations) {
EXPECT_THAT( EXPECT_THAT(
tokenizeText(U"I live on Dr. Dolittle Dr."), tokenizeText(U"Prof. Foo lives on Dr. Dolittle Dr.", [](const string& word) { return word == "prof."; }),
ElementsAre("i", "live", "on", "doctor", "dolittle", "drive") ElementsAre("prof.", "foo", "lives", "on", "doctor", "dolittle", "drive")
); );
} }
TEST(tokenizeText, apostrophes) { TEST(tokenizeText, apostrophes) {
// HACK: "wouldn't" really should not become "wouldnt"!
EXPECT_THAT( EXPECT_THAT(
tokenizeText(U"'Tis said he'd wish'd for a 'bus 'cause he wouldn't walk."), tokenizeText(U"'Tis said he'd wish'd for a 'bus 'cause he wouldn't walk.", [](const string& word) { return word == "wouldn't"; }),
ElementsAreArray(vector<string>{ "tis", "said", "he'd", "wish'd", "for", "a", "bus", "cause", "he", "wouldnt", "walk" }) ElementsAreArray(vector<string>{ "tis", "said", "he'd", "wish'd", "for", "a", "bus", "cause", "he", "wouldn't", "walk" })
); );
} }
TEST(tokenizeText, math) { TEST(tokenizeText, math) {
EXPECT_THAT( EXPECT_THAT(
tokenizeText(U"'1+2*3=7"), tokenizeText(U"'1+2*3=7", returnTrue),
ElementsAre("one", "plus", "two", "times", "three", "equals", "seven") ElementsAre("one", "plus", "two", "times", "three", "equals", "seven")
); );
} }
@ -64,7 +68,7 @@ TEST(tokenizeText, wordsUseLimitedCharacters) {
} }
regex legal("^[a-z']+$"); regex legal("^[a-z']+$");
auto words = tokenizeText(input); auto words = tokenizeText(input, returnTrue);
for (const string& word : words) { for (const string& word : words) {
EXPECT_TRUE(std::regex_match(word, legal)) << word; EXPECT_TRUE(std::regex_match(word, legal)) << word;
} }