virtualx-engine/thirdparty/icu4c/common/lstmbe.cpp

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

857 lines
25 KiB
C++
Raw Normal View History

2021-10-28 08:15:28 +02:00
// © 2021 and later: Unicode, Inc. and others.
// License & terms of use: http://www.unicode.org/copyright.html
#include <complex>
2021-10-28 08:15:28 +02:00
#include <utility>
#include "unicode/utypes.h"
#if !UCONFIG_NO_BREAK_ITERATION
#include "brkeng.h"
#include "charstr.h"
#include "cmemory.h"
#include "lstmbe.h"
#include "putilimp.h"
#include "uassert.h"
#include "ubrkimpl.h"
#include "uresimp.h"
#include "uvectr32.h"
#include "uvector.h"
#include "unicode/brkiter.h"
#include "unicode/resbund.h"
#include "unicode/ubrk.h"
#include "unicode/uniset.h"
#include "unicode/ustring.h"
#include "unicode/utf.h"
U_NAMESPACE_BEGIN
// Uncomment the following #define to debug.
// #define LSTM_DEBUG 1
// #define LSTM_VECTORIZER_DEBUG 1
/**
 * Read-only access to a one-dimensional sequence of floats.
 * Implementations expose their length through d1() and individual
 * elements through get().
 */
class ReadArray1D {
public:
    virtual ~ReadArray1D();

    // Number of elements.
    virtual int32_t d1() const = 0;

    // Element stored at position i.
    virtual float get(int32_t i) const = 0;

#ifdef LSTM_DEBUG
    // Debug aid: print every element to stdout, four values per line.
    void print() const {
        printf("\n[");
        int32_t pos = 0;
        while (pos < d1()) {
            printf("%0.8e ", get(pos));
            const bool endOfGroup = (pos % 4 == 3);
            if (endOfGroup) {
                printf("\n");
            }
            ++pos;
        }
        printf("]\n");
    }
#endif
};

// The destructor is defined out of line, outside the class body.
ReadArray1D::~ReadArray1D() = default;
/**
 * Read-only access to a two-dimensional array of floats.
 * d1() and d2() report the extent of the first and second index
 * respectively; get(i, j) returns a single element.
 */
class ReadArray2D {
public:
    virtual ~ReadArray2D();

    // Extent of the first index.
    virtual int32_t d1() const = 0;

    // Extent of the second index.
    virtual int32_t d2() const = 0;

    // Element stored at (i, j).
    virtual float get(int32_t i, int32_t j) const = 0;
};

// The destructor is defined out of line, outside the class body.
ReadArray2D::~ReadArray2D() = default;
/**
 * Presents an existing float buffer as a 1D array. The object neither owns
 * nor copies the buffer, so the buffer must outlive the view.
 */
class ConstArray1D : public ReadArray1D {
public:
    // An empty view: null data, zero length.
    ConstArray1D() = default;

    // View over d1 floats starting at data.
    ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}

    virtual ~ConstArray1D();

    // Re-point the view at d1 values starting at data, without copying or
    // taking ownership. The int32_t words are reinterpreted as floats so the
    // view can sit directly on top of memory mapped resources.
    void init(const int32_t* data, int32_t d1) {
        U_ASSERT(IEEE_754 == 1);
        d1_ = d1;
        data_ = reinterpret_cast<const float*>(data);
    }

    // ReadArray1D methods.
    virtual int32_t d1() const override { return d1_; }

    virtual float get(int32_t i) const override {
        U_ASSERT(i < d1_);
        return *(data_ + i);
    }

private:
    const float* data_ = nullptr;  // borrowed, never freed here
    int32_t d1_ = 0;               // number of floats reachable through data_
};

ConstArray1D::~ConstArray1D() = default;
/**
 * Presents an existing float buffer as a row-major 2D array. The object
 * neither owns nor copies the buffer, so the buffer must outlive the view.
 */
class ConstArray2D : public ReadArray2D {
public:
    // An empty view: null data, zero extents.
    ConstArray2D() = default;

    // View over d1 rows of d2 floats starting at data.
    ConstArray2D(const float* data, int32_t d1, int32_t d2)
        : data_(data), d1_(d1), d2_(d2) {}

    virtual ~ConstArray2D();

    // Re-point the view at a d1 x d2 matrix starting at data, without copying
    // or taking ownership. The int32_t words are reinterpreted as floats so
    // the view can sit directly on top of memory mapped resources.
    void init(const int32_t* data, int32_t d1, int32_t d2) {
        U_ASSERT(IEEE_754 == 1);
        d1_ = d1;
        d2_ = d2;
        data_ = reinterpret_cast<const float*>(data);
    }

    // ReadArray2D methods.
    inline int32_t d1() const override { return d1_; }
    inline int32_t d2() const override { return d2_; }

    float get(int32_t i, int32_t j) const override {
        U_ASSERT(i < d1_);
        U_ASSERT(j < d2_);
        const int32_t flat = i * d2_ + j;  // row-major position
        return data_[flat];
    }

    // Expose the ith row as a ConstArray1D that shares the same buffer.
    inline ConstArray1D row(int32_t i) const {
        U_ASSERT(i < d1_);
        const float* rowStart = data_ + i * d2_;
        return ConstArray1D(rowStart, d2_);
    }

private:
    const float* data_ = nullptr;  // borrowed, never freed here
    int32_t d1_ = 0;               // number of rows
    int32_t d2_ = 0;               // number of columns
};

ConstArray2D::~ConstArray2D() = default;
/**
 * A writable 1D array of floats and the main implementation of the vector
 * and matrix operations used by the LSTM.
 *
 * An instance either owns its storage (allocated by the (d1, status)
 * constructor and released in the destructor) or is a non-owning slice over
 * storage owned elsewhere (memory_ is nullptr in that case).
 *
 * Copying is disabled: a copy of an owning instance would free the same
 * buffer twice. Moving transfers ownership and leaves the source empty, which
 * is enough to return instances by value from slice() and Array2D::row().
 */
class Array1D : public ReadArray1D {
public:
    // An empty array: no storage, zero length.
    Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}

    // Allocate a zero-filled array of d1 floats.
    // If the allocation fails, status is set to U_MEMORY_ALLOCATION_ERROR.
    // If status already holds a failure the contents are left uninitialized.
    Array1D(int32_t d1, UErrorCode &status)
        : memory_(uprv_malloc(d1 * sizeof(float))),
          data_(static_cast<float*>(memory_)), d1_(d1) {
        if (U_SUCCESS(status)) {
            if (memory_ == nullptr) {
                status = U_MEMORY_ALLOCATION_ERROR;
                return;
            }
            clear();
        }
    }

    virtual ~Array1D();

    // A special constructor which does not own the memory but is writable,
    // used to expose a slice of another array.
    Array1D(float* data, int32_t d1)
        : memory_(nullptr), data_(data), d1_(d1) {}

    // Transfer storage (and ownership, if any) from other, leaving it empty.
    Array1D(Array1D&& other) noexcept
        : memory_(other.memory_), data_(other.data_), d1_(other.d1_) {
        other.memory_ = nullptr;
        other.data_ = nullptr;
        other.d1_ = 0;
    }

    Array1D(const Array1D&) = delete;
    Array1D& operator=(const Array1D&) = delete;
    Array1D& operator=(Array1D&&) = delete;

    // ReadArray1D methods.
    virtual int32_t d1() const override { return d1_; }

    virtual float get(int32_t i) const override {
        U_ASSERT(i >= 0);
        U_ASSERT(i < d1_);
        return data_[i];
    }

    // Return the index of the largest element; the first one wins on ties.
    // Returns 0 for an empty array.
    inline int32_t maxIndex() const {
        if (d1_ <= 0) {
            return 0;  // nothing to read; avoid touching data_[0]
        }
        int32_t index = 0;
        float max = data_[0];
        for (int32_t i = 1; i < d1_; i++) {
            if (data_[i] > max) {
                max = data_[i];
                index = i;
            }
        }
        return index;
    }

    // Return a non-owning, writable view of size elements starting at from.
    inline Array1D slice(int32_t from, int32_t size) const {
        U_ASSERT(from >= 0);
        U_ASSERT(from < d1_);
        U_ASSERT(from + size <= d1_);
        return Array1D(data_ + from, size);
    }

    // Add the dot product of a 1D array and a 2D array into this one:
    // this[i] += sum over j of a[j] * b[j][i].
    inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
        U_ASSERT(a.d1() == b.d1());
        U_ASSERT(b.d2() == d1());
        for (int32_t i = 0; i < d1(); i++) {
            for (int32_t j = 0; j < a.d1(); j++) {
                data_[i] += a.get(j) * b.get(j, i);
            }
        }
        return *this;
    }

    // Multiply this array element-wise (Hadamard product) by another array
    // of the same size.
    inline Array1D& hadamardProduct(const ReadArray1D& a) {
        U_ASSERT(a.d1() == d1());
        for (int32_t i = 0; i < d1(); i++) {
            data_[i] *= a.get(i);
        }
        return *this;
    }

    // Add the Hadamard product of two arrays of the same size into this one.
    inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
        U_ASSERT(a.d1() == d1());
        U_ASSERT(b.d1() == d1());
        for (int32_t i = 0; i < d1(); i++) {
            data_[i] += a.get(i) * b.get(i);
        }
        return *this;
    }

    // Add the values of another array of the same size into this one.
    inline Array1D& add(const ReadArray1D& a) {
        U_ASSERT(a.d1() == d1());
        for (int32_t i = 0; i < d1(); i++) {
            data_[i] += a.get(i);
        }
        return *this;
    }

    // Overwrite this array with the values of another array of the same size.
    inline Array1D& assign(const ReadArray1D& a) {
        U_ASSERT(a.d1() == d1());
        for (int32_t i = 0; i < d1(); i++) {
            data_[i] = a.get(i);
        }
        return *this;
    }

    // Apply tanh to all the elements in the array, in place.
    inline Array1D& tanh() {
        return tanh(*this);
    }

    // Store tanh of every element of a into this array (a may be *this).
    inline Array1D& tanh(const Array1D& a) {
        U_ASSERT(a.d1() == d1());
        for (int32_t i = 0; i < d1_; i++) {
            data_[i] = std::tanh(a.get(i));
        }
        return *this;
    }

    // Apply the logistic sigmoid 1 / (1 + e^-x) to all the elements in place.
    inline Array1D& sigmoid() {
        for (int32_t i = 0; i < d1_; i++) {
            data_[i] = 1.0f/(1.0f + expf(-data_[i]));
        }
        return *this;
    }

    // Set every element to zero.
    inline Array1D& clear() {
        uprv_memset(data_, 0, d1_ * sizeof(float));
        return *this;
    }

private:
    void* memory_;  // owned allocation, or nullptr for a non-owning slice
    float* data_;   // first element (equals memory_ when owning)
    int32_t d1_;    // number of elements
};

// Frees the owned allocation; memory_ is nullptr for slices.
Array1D::~Array1D()
{
    uprv_free(memory_);
}
/**
 * A writable, owning, row-major 2D array of floats.
 *
 * All element counts and offsets are computed in size_t. Multiplying the two
 * int32_t extents directly would overflow (undefined behavior) as soon as
 * d1 * d2 exceeds INT32_MAX, which a long input sequence can reach.
 *
 * Copying and moving are disabled because the object owns its allocation and
 * is never passed or returned by value.
 */
class Array2D : public ReadArray2D {
public:
    // An empty array: no storage, zero extents.
    Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}

    // Allocate a zero-filled d1 x d2 array.
    // If the allocation fails, status is set to U_MEMORY_ALLOCATION_ERROR.
    // If status already holds a failure the contents are left uninitialized.
    Array2D(int32_t d1, int32_t d2, UErrorCode &status)
        : memory_(uprv_malloc(byteSize(d1, d2))),
          data_(static_cast<float*>(memory_)), d1_(d1), d2_(d2) {
        if (U_SUCCESS(status)) {
            if (memory_ == nullptr) {
                status = U_MEMORY_ALLOCATION_ERROR;
                return;
            }
            clear();
        }
    }

    virtual ~Array2D();

    Array2D(const Array2D&) = delete;
    Array2D& operator=(const Array2D&) = delete;

    // ReadArray2D methods.
    virtual int32_t d1() const override { return d1_; }
    virtual int32_t d2() const override { return d2_; }

    virtual float get(int32_t i, int32_t j) const override {
        U_ASSERT(i < d1_);
        U_ASSERT(j < d2_);
        return data_[offset(i, j)];
    }

    // Expose the ith row as a writable, non-owning Array1D.
    inline Array1D row(int32_t i) const {
        U_ASSERT(i < d1_);
        return Array1D(data_ + offset(i, 0), d2_);
    }

    // Set every element to zero.
    inline Array2D& clear() {
        uprv_memset(data_, 0, byteSize(d1_, d2_));
        return *this;
    }

private:
    // Size in bytes of a d1 x d2 float matrix, widened before multiplying.
    static size_t byteSize(int32_t d1, int32_t d2) {
        return static_cast<size_t>(d1) * static_cast<size_t>(d2) * sizeof(float);
    }

    // Row-major position of element (i, j), widened before multiplying.
    size_t offset(int32_t i, int32_t j) const {
        return static_cast<size_t>(i) * static_cast<size_t>(d2_) + static_cast<size_t>(j);
    }

    void* memory_;  // owned allocation
    float* data_;   // first element (same address as memory_)
    int32_t d1_;    // number of rows
    int32_t d2_;    // number of columns
};

// Frees the owned allocation.
Array2D::~Array2D()
{
    uprv_free(memory_);
}
// Labels produced by the output layer; the numeric value is the index of the
// winning entry in the 4-element output vector (see the "BIES logic" in
// LSTMBreakEngine::divideUpDictionaryRange).
enum LSTMClass {
    BEGIN = 0,
    INSIDE = 1,
    END = 2,
    SINGLE = 3
};
// Kind of text unit the model's dictionary is keyed on. The LSTMData
// constructor picks the value from the resource's "type" string and
// createVectorizer() selects the matching Vectorizer.
enum EmbeddingType {
    UNKNOWN = 0,
    CODE_POINTS = 1,
    GRAPHEME_CLUSTER = 2
};
/**
 * The LSTM model loaded from a resource bundle.
 *
 * The object owns the dictionary hash table and the resource bundle and
 * closes both in its destructor. The weight matrices are non-owning views
 * initialized from the bundle's "data" int vector, and fName is a string
 * obtained from the bundle. Copying is disabled because a copy would close
 * the same hash table and bundle a second time.
 */
struct LSTMData : public UMemory {
    // Load the model from rb; this object closes rb in its destructor.
    LSTMData(UResourceBundle* rb, UErrorCode &status);
    ~LSTMData();

    LSTMData(const LSTMData&) = delete;
    LSTMData& operator=(const LSTMData&) = delete;

    UHashtable* fDict;      // maps a dictionary string to its embedding row
    EmbeddingType fType;    // text unit the dictionary is keyed on
    const char16_t* fName;  // value of the bundle's "model" string

    ConstArray2D fEmbedding;  // (dictionary size + 1) x embedding size
    ConstArray2D fForwardW;   // embedding size x 4*hunits
    ConstArray2D fForwardU;   // hunits x 4*hunits
    ConstArray1D fForwardB;   // 4*hunits
    ConstArray2D fBackwardW;  // embedding size x 4*hunits
    ConstArray2D fBackwardU;  // hunits x 4*hunits
    ConstArray1D fBackwardB;  // 4*hunits
    ConstArray2D fOutputW;    // 2*hunits x 4
    ConstArray1D fOutputB;    // 4

private:
    UResourceBundle* fBundle;  // closed in the destructor
};
/**
 * Load an LSTM model from a resource bundle.
 *
 * Reads "embeddings", "hunits", "type", "model", "data" and "dict" from rb,
 * fills the dictionary hash table and points the weight views into the "data"
 * int vector. The bundle is stored in fBundle and closed by the destructor
 * whether or not loading succeeds.
 *
 * @param rb      Resource bundle holding the model.
 * @param status  Set to U_UNSUPPORTED_ERROR if floats are not IEEE 754, to
 *                U_INVALID_FORMAT_ERROR if the dimensions are not positive or
 *                the "data" vector does not have exactly the length implied
 *                by them, or to whatever error the resource/hash calls report.
 */
LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
    : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
      fBundle(rb)
{
    if (U_FAILURE(status)) {
        return;
    }
    if (IEEE_754 != 1) {
        status = U_UNSUPPORTED_ERROR;
        return;
    }
    LocalUResourceBundlePointer embeddings_res(
        ures_getByKey(rb, "embeddings", nullptr, &status));
    int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
    LocalUResourceBundlePointer hunits_res(
        ures_getByKey(rb, "hunits", nullptr, &status));
    if (U_FAILURE(status)) return;
    int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
    const char16_t* type = ures_getStringByKey(rb, "type", nullptr, &status);
    if (U_FAILURE(status)) return;
    // Any other type string leaves fType as UNKNOWN.
    if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
        fType = CODE_POINTS;
    } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
        fType = GRAPHEME_CLUSTER;
    }
    fName = ures_getStringByKey(rb, "model", nullptr, &status);
    LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
    if (U_FAILURE(status)) return;
    int32_t data_len = 0;
    const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
    fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);

    StackUResourceBundle stackTempBundle;
    ResourceDataValue value;
    ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
    ResourceArray stringArray = value.getArray(status);
    int32_t num_index = stringArray.getSize();
    if (U_FAILURE(status)) { return; }

    // put dict into hash: the i-th dictionary string maps to index i.
    int32_t stringLength;
    for (int32_t idx = 0; idx < num_index; idx++) {
        stringArray.getValue(idx, value);
        const char16_t* str = value.getString(stringLength, status);
        uhash_putiAllowZero(fDict, const_cast<char16_t*>(str), idx, &status);
        if (U_FAILURE(status)) return;
#ifdef LSTM_VECTORIZER_DEBUG
        printf("Assign [");
        while (*str != 0x0000) {
            printf("U+%04x ", *str);
            str++;
        }
        printf("] map to %d\n", idx);
#endif
    }

    // Validate the layout before slicing the vector: a malformed bundle must
    // produce an error, not out-of-bounds reads. Sizes are computed in 64 bits
    // and each factor is bounded by data_len first so no product can overflow.
    const int64_t limit = data_len;
    const int64_t rows = static_cast<int64_t>(num_index) + 1;  // +1 row for unknown strings
    const int64_t esize = embedding_size;
    const int64_t hunits4 = 4 * static_cast<int64_t>(hunits);
    if (data == nullptr || embedding_size <= 0 || hunits <= 0 ||
            esize > limit || hunits4 > limit) {
        status = U_INVALID_FORMAT_ERROR;
        return;
    }
    const int64_t mat1_size = rows * esize;        // embedding
    const int64_t mat2_size = esize * hunits4;     // forward W
    const int64_t mat3_size = hunits * hunits4;    // forward U
    const int64_t mat4_size = hunits4;             // forward b
    const int64_t mat5_size = mat2_size;           // backward W
    const int64_t mat6_size = mat3_size;           // backward U
    const int64_t mat7_size = mat4_size;           // backward b
    const int64_t mat8_size = 2 * hunits4;         // output W: (2 * hunits) x 4
    const int64_t mat9_size = 4;                   // output b
    if (mat1_size > limit || mat2_size > limit || mat3_size > limit ||
            mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
            mat6_size + mat7_size + mat8_size + mat9_size != limit) {
        status = U_INVALID_FORMAT_ERROR;
        return;
    }

    // Every size is now known to be <= data_len, so the narrowing is safe.
    const int32_t embedding_rows = static_cast<int32_t>(rows);
    const int32_t gates = static_cast<int32_t>(hunits4);
    fEmbedding.init(data, embedding_rows, embedding_size);
    data += mat1_size;
    fForwardW.init(data, embedding_size, gates);
    data += mat2_size;
    fForwardU.init(data, hunits, gates);
    data += mat3_size;
    fForwardB.init(data, gates);
    data += mat4_size;
    fBackwardW.init(data, embedding_size, gates);
    data += mat5_size;
    fBackwardU.init(data, hunits, gates);
    data += mat6_size;
    fBackwardB.init(data, gates);
    data += mat7_size;
    fOutputW.init(data, 2 * hunits, 4);
    data += mat8_size;
    fOutputB.init(data, 4);
}
// Closes the dictionary hash table and then the resource bundle that was
// passed to the constructor.
LSTMData::~LSTMData() {
uhash_close(fDict);
ures_close(fBundle);
}
class Vectorizer : public UMemory {
public:
Vectorizer(UHashtable* dict) : fDict(dict) {}
virtual ~Vectorizer();
virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
UVector32 &offsets, UVector32 &indices,
UErrorCode &status) const = 0;
protected:
int32_t stringToIndex(const char16_t* str) const {
2021-10-28 08:15:28 +02:00
UBool found = false;
int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
if (!found) {
ret = fDict->count;
}
#ifdef LSTM_VECTORIZER_DEBUG
printf("[");
while (*str != 0x0000) {
printf("U+%04x ", *str);
str++;
}
printf("] map to %d\n", ret);
#endif
return ret;
}
private:
UHashtable* fDict;
};
Vectorizer::~Vectorizer()
{
}
// Vectorizer that looks up every code point of the text individually.
class CodePointsVectorizer : public Vectorizer {
public:
    CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}

    virtual ~CodePointsVectorizer();

    virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
                           UVector32 &offsets, UVector32 &indices,
                           UErrorCode &status) const override;
};

CodePointsVectorizer::~CodePointsVectorizer() = default;
// Walk the text from startPos up to (not including) endPos. For every code
// point, record its native index in offsets and its dictionary index in
// indices. Nothing is appended if either vector cannot be grown or status
// already holds a failure.
void CodePointsVectorizer::vectorize(
    UText *text, int32_t startPos, int32_t endPos,
    UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
{
    const int32_t span = endPos - startPos;
    if (!offsets.ensureCapacity(span, status) ||
            !indices.ensureCapacity(span, status)) {
        return;
    }
    if (U_FAILURE(status)) {
        return;
    }

    utext_setNativeIndex(text, startPos);
    // One code unit followed by a NUL: the lookup key for stringToIndex().
    char16_t key[2] = {0, 0};
    while (U_SUCCESS(status)) {
        const int32_t position = static_cast<int32_t>(utext_getNativeIndex(text));
        if (position >= endPos) {
            break;
        }
        // Since the LSTMBreakEngine currently only accepts chars in the BMP,
        // we can ignore the possibility of hitting a supplementary code
        // point.
        key[0] = static_cast<char16_t>(utext_next32(text));
        U_ASSERT(!U_IS_SURROGATE(key[0]));
        offsets.addElement(position, status);
        indices.addElement(stringToIndex(key), status);
    }
}
// Vectorizer that looks up whole grapheme clusters of the text.
class GraphemeClusterVectorizer : public Vectorizer {
public:
    GraphemeClusterVectorizer(UHashtable* dict) : Vectorizer(dict) {}

    virtual ~GraphemeClusterVectorizer();

    virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
                           UVector32 &offsets, UVector32 &indices,
                           UErrorCode &status) const override;
};

GraphemeClusterVectorizer::~GraphemeClusterVectorizer() = default;
// Maximum number of UTF-16 code units extracted for one grapheme cluster.
constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;

// Split [startPos, endPos) into grapheme clusters using a character break
// iterator. For each cluster, record its start offset in offsets and its
// dictionary index in indices.
//
// A cluster longer than MAX_GRAPHEME_CLSTER_LENGTH code units makes
// utext_extract() report a buffer overflow, which is returned through status.
void GraphemeClusterVectorizer::vectorize(
    UText *text, int32_t startPos, int32_t endPos,
    UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
{
    if (U_FAILURE(status)) return;
    if (!offsets.ensureCapacity(endPos - startPos, status) ||
            !indices.ensureCapacity(endPos - startPos, status)) {
        return;
    }
    if (U_FAILURE(status)) return;
    LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
    if (U_FAILURE(status)) return;
    graphemeIter->setText(text, status);
    if (U_FAILURE(status)) return;

    if (startPos != 0) {
        graphemeIter->preceding(startPos);
    }
    int32_t last = startPos;
    int32_t current = startPos;

    // One extra slot that always holds a NUL. When a cluster is exactly
    // MAX_GRAPHEME_CLSTER_LENGTH units long, utext_extract() fills the whole
    // capacity and only reports U_STRING_NOT_TERMINATED_WARNING (not a
    // failure); without this sentinel stringToIndex() would read past the end
    // of the buffer looking for the terminator.
    char16_t str[MAX_GRAPHEME_CLSTER_LENGTH + 1];
    str[MAX_GRAPHEME_CLSTER_LENGTH] = 0;

    while ((current = graphemeIter->next()) != BreakIterator::DONE) {
        if (current >= endPos) {
            break;
        }
        if (current > startPos) {
            utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
            if (U_FAILURE(status)) return;
            offsets.addElement(last, status);
            indices.addElement(stringToIndex(str), status);
            if (U_FAILURE(status)) return;
        }
        last = current;
    }
    if (U_FAILURE(status) || last >= endPos) {
        return;
    }
    // Trailing piece between the last boundary seen and endPos.
    utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
    if (U_SUCCESS(status)) {
        offsets.addElement(last, status);
        indices.addElement(stringToIndex(str), status);
    }
}
// Computing LSTM as stated in
// https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
// ifco is temp array allocate outside which does not need to be
// input/output value but could avoid unnecessary memory alloc/free if passing
// in.
void compute(
int32_t hunits,
const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
const ReadArray1D& x, Array1D& h, Array1D& c,
Array1D& ifco)
{
// ifco = x * W + h * U + b
ifco.assign(b)
.addDotProduct(x, W)
.addDotProduct(h, U);
ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod
ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
c.hadamardProduct(ifco.slice(hunits, hunits))
.addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
h.tanh(c)
.hadamardProduct(ifco.slice(3*hunits, hunits));
}
// Smallest word size.
static constexpr int32_t MIN_WORD = 2;
// Smallest number of characters that can hold two words.
static constexpr int32_t MIN_WORD_SPAN = 2 * MIN_WORD;
int32_t
LSTMBreakEngine::divideUpDictionaryRange( UText *text,
int32_t startPos,
int32_t endPos,
UVector32 &foundBreaks,
UBool /* isPhraseBreaking */,
2021-10-28 08:15:28 +02:00
UErrorCode& status) const {
if (U_FAILURE(status)) return 0;
int32_t beginFoundBreakSize = foundBreaks.size();
utext_setNativeIndex(text, startPos);
utext_moveIndex32(text, MIN_WORD_SPAN);
if (utext_getNativeIndex(text) >= endPos) {
return 0; // Not enough characters for two words
}
utext_setNativeIndex(text, startPos);
UVector32 offsets(status);
UVector32 indices(status);
if (U_FAILURE(status)) return 0;
fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
if (U_FAILURE(status)) return 0;
int32_t* offsetsBuf = offsets.getBuffer();
int32_t* indicesBuf = indices.getBuffer();
int32_t input_seq_len = indices.size();
int32_t hunits = fData->fForwardU.d1();
// ----- Begin of all the Array memory allocation needed for this function
// Allocate temp array used inside compute()
Array1D ifco(4 * hunits, status);
Array1D c(hunits, status);
Array1D logp(4, status);
// TODO: limit size of hBackward. If input_seq_len is too big, we could
// run out of memory.
// Backward LSTM
Array2D hBackward(input_seq_len, hunits, status);
// Allocate fbRow and slice the internal array in two.
Array1D fbRow(2 * hunits, status);
// ----- End of all the Array memory allocation needed for this function
if (U_FAILURE(status)) return 0;
// To save the needed memory usage, the following is different from the
// Python or ICU4X implementation. We first perform the Backward LSTM
// and then merge the iteration of the forward LSTM and the output layer
// together because we only neetdto remember the h[t-1] for Forward LSTM.
for (int32_t i = input_seq_len - 1; i >= 0; i--) {
Array1D hRow = hBackward.row(i);
if (i != input_seq_len - 1) {
hRow.assign(hBackward.row(i+1));
}
#ifdef LSTM_DEBUG
printf("hRow %d\n", i);
hRow.print();
printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
fData->fEmbedding.row(indicesBuf[i]).print();
#endif // LSTM_DEBUG
compute(hunits,
fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
fData->fEmbedding.row(indicesBuf[i]),
hRow, c, ifco);
}
Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow.
Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow.
// The following iteration merge the forward LSTM and the output layer
// together.
c.clear(); // reuse c since it is the same size.
for (int32_t i = 0; i < input_seq_len; i++) {
#ifdef LSTM_DEBUG
printf("forwardRow %d\n", i);
forwardRow.print();
#endif // LSTM_DEBUG
// Forward LSTM
// Calculate the result into forwardRow, which point to the data in the first half
// of fbRow.
compute(hunits,
fData->fForwardW, fData->fForwardU, fData->fForwardB,
fData->fEmbedding.row(indicesBuf[i]),
forwardRow, c, ifco);
// assign the data from hBackward.row(i) to second half of fbRowa.
backwardRow.assign(hBackward.row(i));
logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
#ifdef LSTM_DEBUG
printf("backwardRow %d\n", i);
backwardRow.print();
printf("logp %d\n", i);
logp.print();
#endif // LSTM_DEBUG
// current = argmax(logp)
LSTMClass current = (LSTMClass)logp.maxIndex();
// BIES logic.
if (current == BEGIN || current == SINGLE) {
if (i != 0) {
foundBreaks.addElement(offsetsBuf[i], status);
if (U_FAILURE(status)) return 0;
}
}
}
return foundBreaks.size() - beginFoundBreakSize;
}
// Create the Vectorizer matching data->fType.
//
// Returns nullptr if status already holds a failure, or if the allocation
// fails, in which case status is set to U_MEMORY_ALLOCATION_ERROR so the
// caller never ends up with a null vectorizer and a success status. Any
// embedding type other than CODE_POINTS or GRAPHEME_CLUSTER is treated as
// unreachable.
Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
    if (U_FAILURE(status)) {
        return nullptr;
    }
    Vectorizer* vectorizer = nullptr;
    switch (data->fType) {
        case CODE_POINTS:
            vectorizer = new CodePointsVectorizer(data->fDict);
            break;
        case GRAPHEME_CLUSTER:
            vectorizer = new GraphemeClusterVectorizer(data->fDict);
            break;
        default:
            UPRV_UNREACHABLE_EXIT;
    }
    if (vectorizer == nullptr) {
        status = U_MEMORY_ALLOCATION_ERROR;
    }
    return vectorizer;
}
// Build an engine around data that handles the characters in set.
// On success the engine deletes data in its destructor. On failure the
// engine forgets data, so the destructor does not delete it.
LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
    : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
{
    if (U_SUCCESS(status)) {
        setCharacters(set);
    } else {
        // If failure, we should not delete fData in destructor because the caller will do so.
        fData = nullptr;
    }
}
// Deletes the model data and the vectorizer. fData is nullptr here when the
// constructor failed, so in that case the model data is not deleted.
LSTMBreakEngine::~LSTMBreakEngine() {
delete fData;
delete fVectorizer;
}
const char16_t* LSTMBreakEngine::name() const {
2021-10-28 08:15:28 +02:00
return fData->fName;
}
// Look up the string stored for the given script under the "lstm" table of
// the root break-iterator bundle. Errors are reported through status.
UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
    // open root from brkitr tree.
    UResourceBundle* bundle = ures_open(U_ICUDATA_BRKITR, "", &status);
    // The same bundle object is passed as the fill-in for the "lstm" table.
    bundle = ures_getByKeyWithFallback(bundle, "lstm", bundle, &status);
    const char* scriptKey = uscript_getShortName(script);
    UnicodeString modelName = ures_getUnicodeStringByKey(bundle, scriptKey, &status);
    ures_close(bundle);
    return modelName;
}
// Load the default LSTM model for a script.
//
// Returns nullptr without touching status for scripts other than Khmer, Lao,
// Myanmar and Thai, and nullptr with a failure status if the model name or
// its bundle cannot be loaded. Otherwise returns whatever CreateLSTMData()
// produces for the opened bundle.
U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
{
    if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
        return nullptr;
    }
    UnicodeString name = defaultLSTM(script, status);
    if (U_FAILURE(status)) return nullptr;

    // Turn the file name into a bundle name by dropping the extension.
    // The append and the search for the dot are separate statements: chained
    // as a.append(...).truncate(a.lastIndexOf('.')), the order in which the
    // two calls on namebuf run is unspecified before C++17.
    CharString namebuf;
    namebuf.appendInvariantChars(name, status);
    if (U_FAILURE(status)) return nullptr;
    const int32_t dot = namebuf.lastIndexOf('.');
    if (dot >= 0) {
        // Without this guard a name with no extension would be truncated to
        // an empty string.
        namebuf.truncate(dot);
    }

    LocalUResourceBundlePointer rb(
        ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
    if (U_FAILURE(status)) return nullptr;

    // Ownership of the bundle passes to CreateLSTMData().
    return CreateLSTMData(rb.orphan(), status);
}
// Create an LSTMData from a resource bundle. The returned object closes rb
// in its destructor.
//
// The object is returned even when loading failed, so callers must check
// status and dispose of the result in that case. If the object itself cannot
// be allocated, rb is closed here (nothing else would close it), status is
// set to U_MEMORY_ALLOCATION_ERROR unless it already holds a failure, and
// nullptr is returned.
U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
{
    LSTMData* data = new LSTMData(rb, status);
    if (data == nullptr) {
        ures_close(rb);
        if (U_SUCCESS(status)) {
            status = U_MEMORY_ALLOCATION_ERROR;
        }
    }
    return data;
}
// Create an LSTM break engine for Thai or Myanmar text.
//
// For any other script, data is deleted and nullptr is returned. If building
// the engine fails, nullptr is returned with a failure status; data is not
// deleted on that path.
U_CAPI const LanguageBreakEngine* U_EXPORT2
CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
{
    // Pattern selecting the characters this engine handles.
    const char16_t* pattern = nullptr;
    switch(script) {
        case USCRIPT_THAI:
            pattern = u"[[:Thai:]&[:LineBreak=SA:]]";
            break;
        case USCRIPT_MYANMAR:
            pattern = u"[[:Mymr:]&[:LineBreak=SA:]]";
            break;
        default:
            delete data;
            return nullptr;
    }
    UnicodeString unicodeSetString = UnicodeString(pattern);
    UnicodeSet unicodeSet;
    unicodeSet.applyPattern(unicodeSetString, status);

    const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
    if (engine == nullptr) {
        status = U_MEMORY_ALLOCATION_ERROR;
        return nullptr;
    }
    if (U_FAILURE(status)) {
        // A failed constructor has dropped its reference to data, so
        // deleting the engine leaves data alone.
        delete engine;
        return nullptr;
    }
    return engine;
}
// Deletes an LSTMData object. Passing nullptr is harmless because deleting a
// null pointer does nothing.
U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
{
delete data;
}
// Return the model name stored in data. The pointer is the object's own
// fName member, not a copy.
U_CAPI const char16_t* U_EXPORT2 LSTMDataName(const LSTMData* data)
{
    const char16_t* modelName = data->fName;
    return modelName;
}
U_NAMESPACE_END
#endif /* #if !UCONFIG_NO_BREAK_ITERATION */