diff --git a/opennlp-wsd/.project b/opennlp-wsd/.project deleted file mode 100644 index a15cd737..00000000 --- a/opennlp-wsd/.project +++ /dev/null @@ -1,11 +0,0 @@ - - - opennlp-wsd - - - - - - - - diff --git a/opennlp-wsd/src/main/java/opennlp/tools/cmdline/disambiguator/DisambiguatorTool.java b/opennlp-wsd/src/main/java/opennlp/tools/cmdline/disambiguator/DisambiguatorTool.java index 54767e67..4a89b856 100644 --- a/opennlp-wsd/src/main/java/opennlp/tools/cmdline/disambiguator/DisambiguatorTool.java +++ b/opennlp-wsd/src/main/java/opennlp/tools/cmdline/disambiguator/DisambiguatorTool.java @@ -108,7 +108,7 @@ public static Disambiguator makeTool(DisambiguatorToolParams params) { wsd = new Lesk(); } else if (params.getType().equalsIgnoreCase("ims")) { // TODO Set a "default" model for ENG -> future!? - wsd = new WSDisambiguatorME(null, new WSDDefaultParameters()); + wsd = new WSDisambiguatorME(null, WSDDefaultParameters.defaultParams()); } return wsd; diff --git a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDDefaultParameters.java b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDDefaultParameters.java index 5fce6dfd..5e271548 100644 --- a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDDefaultParameters.java +++ b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDDefaultParameters.java @@ -19,8 +19,14 @@ package opennlp.tools.disambiguator; -import java.io.File; -import java.nio.file.Path; +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import java.util.Properties; +import java.util.TreeMap; + +import opennlp.tools.cmdline.CmdLineUtil; +import opennlp.tools.commons.Internal; /** * Defines the parameters for the @@ -31,88 +37,476 @@ */ public class WSDDefaultParameters extends WSDParameters { - public static final int DFLT_WIN_SIZE = 3; - public static final int DFLT_NGRAM = 2; - public static final String DFLT_LANG_CODE = "en"; - public static final SenseSource DFLT_SOURCE = SenseSource.WORDNET; - - private final Path trainingDataDir; - - private final String languageCode; - protected int windowSize; - protected int ngram; - - /** - * Initializes a new set of {@link WSDDefaultParameters}. - * The default language used is 'en' (English). - * - * @param windowSize The size of the window used for the extraction of the features - * qualified of Surrounding Words. - * @param ngram The number words used for the extraction of features qualified of - * Local Collocations. - * @param senseSource The {@link SenseSource source} of the training data - * @param trainingDataDir The {@link Path} where to store or read trained models from. - */ - public WSDDefaultParameters(int windowSize, int ngram, SenseSource senseSource, Path trainingDataDir) { - this.languageCode = DFLT_LANG_CODE; - this.windowSize = windowSize; - this.ngram = ngram; - this.senseSource = senseSource; - this.trainingDataDir = trainingDataDir; - if (trainingDataDir != null) { - File folder = trainingDataDir.toFile(); - if (!folder.exists()) - folder.mkdirs(); + public static final String WINDOW_SIZE_PARAM = "WindowSize"; + public static final String NGRAM_PARAM = "NGram"; + public static final String LANG_CODE = "LangCode"; + public static final String SENSE_SOURCE_PARAM = "SenseSource"; + public static final String TRAINING_DIR_PARAM = "TrainingDirectory"; + + /** + * The default window size is 3. + */ + public static final int WINDOW_SIZE_DEFAULT = 3; + + /** + * The default ngram width is 2. + */ + public static final int NGRAM_DEFAULT = 2; + + /** + * The default ISO language code is 'en'. + */ + public static final String LANG_CODE_DEFAULT = "en"; + + /** + * The default SenseSource is 'WORDNET'. + */ + public static final SenseSource SOURCE_DEFAULT = SenseSource.WORDNET; + + private final Map parameters = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + + /** + * No-arg constructor to create a basic {@link WSDDefaultParameters} instance. + */ + @Internal + WSDDefaultParameters() { + } + + /** + * Key-value based constructor to apply a {@link Map} based configuration initialization. + */ + public WSDDefaultParameters(Map map) { + parameters.putAll(map); + } + + /** + * {@link InputStream} based constructor that reads in {@link WSDDefaultParameters}. + * + * @param in The {@link InputStream} to a kay-value based file that defines {@link WSDParameters}. + * @throws IOException Thrown if IO errors occurred. + */ + public WSDDefaultParameters(InputStream in) throws IOException { + final Properties properties = new Properties(); + properties.load(in); + + for (Map.Entry entry : properties.entrySet()) { + parameters.put((String) entry.getKey(), entry.getValue()); } } /** - * Initializes a new set of {@link WSDDefaultParameters}. - * The default language used is 'en' (English), the window size is {@link #DFLT_WIN_SIZE}, - * and the ngram length is initialized as {@link #DFLT_NGRAM}. + * {@inheritDoc} + */ + @Override + public boolean areValid() { + return true; + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}, + * if the value was not present before. + * The {@code namespace} can be used to prefix the {@code key}. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be put. + * May be {@code null}. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link String} parameter to put into this {@link WSDParameters} instance. + */ + public void putIfAbsent(String namespace, String key, String value) { + parameters.putIfAbsent(getKey(namespace, key), value); + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}, + * if the value was not present before. * - * @implNote The training directory will be unset. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link String} parameter to put into this {@link WSDParameters} instance. */ - public WSDDefaultParameters() { - this(DFLT_WIN_SIZE, DFLT_NGRAM, DFLT_SOURCE, null); + public void putIfAbsent(String key, String value) { + putIfAbsent(null, key, value); } /** - * Initializes a new set of {@link WSDDefaultParameters}. - * The default language used is 'en' (English), the window size is {@link #DFLT_WIN_SIZE}, - * and the ngram length is initialized as {@link #DFLT_NGRAM}. + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}, + * if the value was not present before. + * The {@code namespace} can be used to prefix the {@code key}. * - * @param trainingDataDir The {@link Path} where to place or lookup trained models. + * @param namespace A prefix to declare or use a name space under which {@code key} shall be put. + * May be {@code null}. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Integer} parameter to put into this {@link WSDParameters} instance. */ - public WSDDefaultParameters(Path trainingDataDir) { - this(DFLT_WIN_SIZE, DFLT_NGRAM, DFLT_SOURCE, trainingDataDir); + public void putIfAbsent(String namespace, String key, int value) { + parameters.putIfAbsent(getKey(namespace, key), value); } - public String getLanguageCode() { - return languageCode; + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}, + * if the value was not present before. + * + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Integer} parameter to put into this {@link WSDParameters} instance. + */ + public void putIfAbsent(String key, int value) { + putIfAbsent(null, key, value); + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}, + * if the value was not present before. + * The {@code namespace} can be used to prefix the {@code key}. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be put. + * May be {@code null}. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Double} parameter to put into this {@link WSDParameters} instance. + */ + public void putIfAbsent(String namespace, String key, double value) { + parameters.putIfAbsent(getKey(namespace, key), value); } - public int getWindowSize() { - return windowSize; + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}, + * if the value was not present before. + * The {@code namespace} can be used to prefix the {@code key}. + * + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Double} parameter to put into this {@link WSDParameters} instance. + */ + public void putIfAbsent(String key, double value) { + putIfAbsent(null, key, value); } - public int getNgram() { - return ngram; + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}, + * if the value was not present before. + * The {@code namespace} can be used to prefix the {@code key}. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be put. + * May be {@code null}. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Boolean} parameter to put into this {@link WSDParameters} instance. + */ + public void putIfAbsent(String namespace, String key, boolean value) { + parameters.putIfAbsent(getKey(namespace, key), value); } /** - * @return The {@link Path} where to place or lookup trained models. May be {@code null}! + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}, + * if the value was not present before. + * + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Boolean} parameter to put into this {@link WSDParameters} instance. */ - public Path getTrainingDataDirectory() { - return trainingDataDir; + public void putIfAbsent(String key, boolean value) { + putIfAbsent(null, key, value); } /** - * {@inheritDoc} + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}. + * If the value was present before, the previous value will be overwritten with the specified one. + * The {@code namespace} can be used to prefix the {@code key}. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be put. + * May be {@code null}. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link String} parameter to put into this {@link WSDParameters} instance. */ - @Override - public boolean areValid() { - return true; + public void put(String namespace, String key, String value) { + parameters.put(getKey(namespace, key), value); + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}. + * If the value was present before, the previous value will be overwritten with the specified one. + * + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link String} parameter to put into this {@link WSDParameters} instance. + */ + public void put(String key, String value) { + put(null, key, value); + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}. + * If the value was present before, the previous value will be overwritten with the specified one. + * The {@code namespace} can be used to prefix the {@code key}. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be put. + * May be {@code null}. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Integer} parameter to put into this {@link WSDParameters} instance. + */ + public void put(String namespace, String key, int value) { + parameters.put(getKey(namespace, key), value); + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}. + * If the value was present before, the previous value will be overwritten with the specified one. + * + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Integer} parameter to put into this {@link WSDParameters} instance. + */ + public void put(String key, int value) { + put(null, key, value); + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}. + * If the value was present before, the previous value will be overwritten with the specified one. + * The {@code namespace} can be used to prefix the {@code key}. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be put. + * May be {@code null}. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Double} parameter to put into this {@link WSDParameters} instance. + */ + public void put(String namespace, String key, double value) { + parameters.put(getKey(namespace, key), value); } + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}. + * If the value was present before, the previous value will be overwritten with the specified one. + * + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Double} parameter to put into this {@link WSDParameters} instance. + */ + public void put(String key, double value) { + put(null, key, value); + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}. + * If the value was present before, the previous value will be overwritten with the specified one. + * The {@code namespace} can be used to prefix the {@code key}. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be put. + * May be {@code null}. + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Boolean} parameter to put into this {@link WSDParameters} instance. + */ + public void put(String namespace, String key, boolean value) { + parameters.put(getKey(namespace, key), value); + } + + /** + * Puts a {@code value} into the current {@link WSDParameters} under a certain {@code key}. + * If the value was present before, the previous value will be overwritten with the specified one. + * + * @param key The identifying key to put or retrieve a {@code value} with. + * @param value The {@link Boolean} parameter to put into this {@link WSDParameters} instance. + */ + public void put(String key, boolean value) { + put(null, key, value); + } + + /** + * Obtains a training parameter value. + * + * Note: + * {@link java.lang.ClassCastException} can be thrown if the value is not {@code String} + * + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * @return The {@link String training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public String getStringParameter(String key, String defaultValue) { + return getStringParameter(null, key, defaultValue); + } + + /** + * Obtains a training parameter value in the specified namespace. + * + * Note: + * {@link java.lang.ClassCastException} can be thrown if the value is not {@link String} + * @param namespace A prefix to declare or use a name space under which {@code key} shall be searched. + * May be {@code null}. + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * + * @return The {@link String training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public String getStringParameter(String namespace, String key, String defaultValue) { + Object value = parameters.get(getKey(namespace, key)); + if (value == null) { + return defaultValue; + } + else { + return (String)value; + } + } + + /** + * Obtains a training parameter value. + * + * + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * @return The {@link Integer training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public int getIntParameter(String key, int defaultValue) { + return getIntParameter(null, key, defaultValue); + } + + /** + * Obtains a training parameter value in the specified namespace. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be searched. + * May be {@code null}. + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * + * @return The {@link Integer training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public int getIntParameter(String namespace, String key, int defaultValue) { + Object value = parameters.get(getKey(namespace, key)); + if (value == null) { + return defaultValue; + } + else { + try { + return (Integer) value; + } + catch (ClassCastException e) { + return Integer.parseInt((String)value); + } + } + } + + /** + * Obtains a training parameter value. + * + * + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * @return The {@link Double training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public double getDoubleParameter(String key, double defaultValue) { + return getDoubleParameter(null, key, defaultValue); + } + + /** + * Obtains a training parameter value in the specified namespace. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be searched. + * May be {@code null}. + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * + * @return The {@link Double training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public double getDoubleParameter(String namespace, String key, double defaultValue) { + Object value = parameters.get(getKey(namespace, key)); + if (value == null) { + return defaultValue; + } + else { + try { + return (Double) value; + } + catch (ClassCastException e) { + return Double.parseDouble((String)value); + } + } + } + + /** + * Obtains a training parameter value. + * + * + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * @return The {@link Boolean training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public boolean getBooleanParameter(String key, boolean defaultValue) { + return getBooleanParameter(null, key, defaultValue); + } + + /** + * Obtains a training parameter value in the specified namespace. + * + * @param namespace A prefix to declare or use a name space under which {@code key} shall be searched. + * May be {@code null}. + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * + * @return The {@link Boolean training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public boolean getBooleanParameter(String namespace, String key, boolean defaultValue) { + Object value = parameters.get(getKey(namespace, key)); + if (value == null) { + return defaultValue; + } + else { + try { + return (Boolean) value; + } + catch (ClassCastException e) { + return Boolean.parseBoolean((String)value); + } + } + } + + /** + * @return Retrieves a new {@link WSDDefaultParameters instance} initialized with default values. + */ + public static WSDDefaultParameters defaultParams() { + WSDDefaultParameters wsdParams = new WSDDefaultParameters(); + wsdParams.put(WSDDefaultParameters.LANG_CODE, LANG_CODE_DEFAULT); + wsdParams.put(WSDDefaultParameters.WINDOW_SIZE_PARAM, WINDOW_SIZE_DEFAULT); + wsdParams.put(WSDDefaultParameters.NGRAM_PARAM, NGRAM_DEFAULT); + wsdParams.put(WSDDefaultParameters.SENSE_SOURCE_PARAM, SOURCE_DEFAULT.name()); + return wsdParams; + } + + /** + * @param params The parameters to additionally apply into the new {@link WSDDefaultParameters instance}. + * + * @return Retrieves a new {@link WSDDefaultParameters instance} initialized with given parameter values. + */ + public static WSDDefaultParameters setParams(String[] params) { + WSDDefaultParameters wsdParams = new WSDDefaultParameters(); + wsdParams.put(WSDDefaultParameters.LANG_CODE, LANG_CODE_DEFAULT); + wsdParams.put(WSDDefaultParameters.SENSE_SOURCE_PARAM, SOURCE_DEFAULT.name()); + wsdParams.put(WSDDefaultParameters.WINDOW_SIZE_PARAM , + null != CmdLineUtil.getIntParameter("-" + + WSDDefaultParameters.WINDOW_SIZE_PARAM.toLowerCase() , params) ? + CmdLineUtil.getIntParameter("-" + WSDDefaultParameters.WINDOW_SIZE_PARAM.toLowerCase() , params) : + WINDOW_SIZE_DEFAULT); + wsdParams.put(WSDDefaultParameters.NGRAM_PARAM , + null != CmdLineUtil.getIntParameter("-" + + WSDDefaultParameters.NGRAM_PARAM.toLowerCase() , params) ? + CmdLineUtil.getIntParameter("-" + WSDDefaultParameters.NGRAM_PARAM.toLowerCase() , params) : + NGRAM_DEFAULT); + + return wsdParams; + } + + /** + * @param namespace The namespace used as prefix or {@code null}. + * If {@code null} the {@code key} is left unchanged. + * @param key The identifying key to process. + * + * @return Retrieves a prefixed key in the specified {@code namespace}. + * If no {@code namespace} was specified the returned String is equal to {@code key}. + */ + static String getKey(String namespace, String key) { + if (namespace == null) { + return key; + } + else { + return namespace + "." + key; + } + } } diff --git a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDModel.java b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDModel.java index 87bc3d2c..299d6603 100644 --- a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDModel.java +++ b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDModel.java @@ -45,7 +45,7 @@ public class WSDModel extends BaseModel { @Serial - private static final long serialVersionUID = 2961852011373749729L; + private static final long serialVersionUID = -5191919760365361954L; private static final String COMPONENT_NAME = "WSD"; private static final String WSD_MODEL_ENTRY = "WSD.model"; @@ -159,6 +159,13 @@ public String getWordTag() { return getManifestProperty(WORDTAG); } + /** + * @return Retrieves the active {@link WSDisambiguatorFactory}. + */ + public WSDisambiguatorFactory getWSDFactory() { + return (WSDisambiguatorFactory) this.toolFactory; + } + /** * {@inheritDoc} */ diff --git a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDisambiguatorME.java b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDisambiguatorME.java index 387699f7..024c4620 100644 --- a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDisambiguatorME.java +++ b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDisambiguatorME.java @@ -17,6 +17,13 @@ package opennlp.tools.disambiguator; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import opennlp.tools.ml.EventTrainer; import opennlp.tools.ml.TrainerFactory; import opennlp.tools.ml.model.Event; @@ -25,14 +32,6 @@ import opennlp.tools.util.ObjectStreamUtils; import opennlp.tools.util.TrainingParameters; -import java.io.File; -import java.io.IOException; -import java.security.InvalidParameterException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; - /** * A {@link Disambiguator} implementation based on a Maximum Entropy (ME) approach. * @@ -45,8 +44,7 @@ */ public class WSDisambiguatorME extends AbstractWSDisambiguator { - protected static final WSDContextGenerator CONTEXT_GENERATOR = new IMSWSDContextGenerator(); - + private final WSDContextGenerator cg; private final WSDModel model; /** @@ -60,8 +58,10 @@ public WSDisambiguatorME(WSDModel model, WSDParameters params) { if (model == null || params == null) { throw new IllegalArgumentException("Parameters cannot be null!"); } - this.model = model; super.params = params; + this.model = model; + WSDisambiguatorFactory factory = model.getWSDFactory(); + cg = factory.getContextGenerator(); } /** @@ -84,47 +84,45 @@ public WSDModel getModel() { * during training. Or if reading from the {@link ObjectStream} fails. */ public static WSDModel train(String lang, ObjectStream samples, - TrainingParameters mlParams, WSDParameters params) throws IOException { - - WSDDefaultParameters defParams = ((WSDDefaultParameters) params); - List surroundingContext = buildSurroundingContext(samples, defParams.getWindowSize()); - - HashMap manifestInfoEntries = new HashMap<>(); + TrainingParameters mlParams, WSDParameters params, + WSDisambiguatorFactory factory) throws IOException { - MaxentModel meModel; + final WSDDefaultParameters defParams = ((WSDDefaultParameters) params); + final int wSize = defParams.getIntParameter( + WSDDefaultParameters.WINDOW_SIZE_PARAM, WSDDefaultParameters.WINDOW_SIZE_DEFAULT); + final int ngram = defParams.getIntParameter( + WSDDefaultParameters.NGRAM_PARAM, WSDDefaultParameters.NGRAM_DEFAULT); + List surroundingContext = buildSurroundingContext(samples, wSize); List events = new ArrayList<>(); - ObjectStream es; - WSDSample sample = samples.read(); String wordTag = ""; if (sample != null) { + final WSDContextGenerator cg = factory.getContextGenerator(); wordTag = sample.getTargetWordTag(); do { String sense = sample.getSenseIDs()[0]; - String[] context = CONTEXT_GENERATOR.getContext(sample, - defParams.ngram, defParams.windowSize, surroundingContext); + String[] context = cg.getContext(sample, ngram, wSize, surroundingContext); Event ev = new Event(sense, context); events.add(ev); } while ((sample = samples.read()) != null); } - es = ObjectStreamUtils.createObjectStream(events); + final Map manifestInfoEntries = new HashMap<>(); + ObjectStream es = ObjectStreamUtils.createObjectStream(events); EventTrainer trainer = TrainerFactory.getEventTrainer(mlParams, manifestInfoEntries); + MaxentModel meModel = trainer.train(es); - meModel = trainer.train(es); - - return new WSDModel(lang, wordTag, defParams.windowSize, defParams.ngram, - meModel, surroundingContext, manifestInfoEntries); + return new WSDModel(lang, wordTag, wSize, ngram, meModel, surroundingContext, manifestInfoEntries); } private static List buildSurroundingContext(ObjectStream samples, int windowSize) throws IOException { - IMSWSDContextGenerator contextGenerator = new IMSWSDContextGenerator(); + IMSWSDContextGenerator cg = new IMSWSDContextGenerator(); List surroundingWordsModel = new ArrayList<>(); WSDSample sample; while ((sample = samples.read()) != null) { - String[] words = contextGenerator.extractSurroundingContext(sample.getTargetPosition(), + String[] words = cg.extractSurroundingContext(sample.getTargetPosition(), sample.getSentence(), sample.getLemmas(), windowSize); if (words.length > 0) { @@ -150,14 +148,17 @@ public String disambiguate(String[] tokenizedContext, String[] tokenTags, @Override public String disambiguate(WSDSample sample) { final WSDDefaultParameters defParams = ((WSDDefaultParameters) params); - final String wordTag = sample.getTargetWordTag(); + final int wSize = defParams.getIntParameter( + WSDDefaultParameters.WINDOW_SIZE_PARAM, WSDDefaultParameters.WINDOW_SIZE_DEFAULT); + final int ngram = defParams.getIntParameter( + WSDDefaultParameters.NGRAM_PARAM, WSDDefaultParameters.NGRAM_DEFAULT); + final String wordTag = sample.getTargetWordTag(); if (WSDHelper.isRelevantPOSTag(sample.getTargetTag())) { if (!model.getWordTag().equals(wordTag)) { return disambiguate(wordTag); } else { - String[] context = CONTEXT_GENERATOR.getContext(sample, - defParams.ngram, defParams.windowSize, this.model.getContextEntries()); + String[] context = cg.getContext(sample, wSize, ngram, this.model.getContextEntries()); double[] outcomeProbs = model.getWSDMaxentModel().eval(context); String outcome = model.getWSDMaxentModel().getBestOutcome(outcomeProbs); diff --git a/opennlp-wsd/src/test/java/opennlp/tools/cmdline/wsd/WSDModelLoaderTest.java b/opennlp-wsd/src/test/java/opennlp/tools/cmdline/wsd/WSDModelLoaderTest.java index d76a71ae..b7730204 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/cmdline/wsd/WSDModelLoaderTest.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/cmdline/wsd/WSDModelLoaderTest.java @@ -36,6 +36,7 @@ import opennlp.tools.AbstractTest; import opennlp.tools.disambiguator.WSDDefaultParameters; +import opennlp.tools.disambiguator.WSDisambiguatorFactory; import opennlp.tools.disambiguator.WSDModel; import opennlp.tools.disambiguator.WSDSample; import opennlp.tools.disambiguator.WSDisambiguatorME; @@ -64,18 +65,25 @@ static void createSimpleWSDModel(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmp Path workDir = tmpDir.resolve("models" + File.separatorChar); trainingDir = workDir.resolve("training" + File.separatorChar) .resolve("supervised" + File.separatorChar); + File folder = trainingDir.toFile(); + if (!folder.exists()) { + assertTrue(folder.mkdirs()); + } + final TrainingParameters params = TrainingParameters.defaultParams(); params.put(TrainingParameters.THREADS_PARAM, 4); - + final WSDDefaultParameters wsdParams = WSDDefaultParameters.defaultParams(); + wsdParams.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); + + final WSDisambiguatorFactory factory = new WSDisambiguatorFactory(); final SemcorReaderExtended sr = new SemcorReaderExtended(SEMCOR_DIR); final ObjectStream samples = sr.getSemcorDataStream(WORD_TAG); try { - WSDDefaultParameters wsdParams = new WSDDefaultParameters(trainingDir); - trainedModel = WSDisambiguatorME.train("en", samples, params, wsdParams); + trainedModel = WSDisambiguatorME.train("en", samples, params, wsdParams, factory); assertNotNull(trainedModel); - File modelFile = new File(wsdParams.getTrainingDataDirectory() + - Character.toString(File.separatorChar) + WORD_TAG + ".wsd.model"); + File modelFile = new File(wsdParams.getStringParameter( + WSDDefaultParameters.TRAINING_DIR_PARAM, "") + File.separatorChar + WORD_TAG + ".wsd.model"); try (OutputStream modelOut = new BufferedOutputStream(new FileOutputStream(modelFile))) { trainedModel.serialize(modelOut); } diff --git a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDDefaultParametersTest.java b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDDefaultParametersTest.java index 691d3b33..8a019250 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDDefaultParametersTest.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDDefaultParametersTest.java @@ -45,15 +45,20 @@ static void initEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { @Test void testCreate() { - WSDDefaultParameters params = new WSDDefaultParameters(trainingDir); + WSDDefaultParameters params = WSDDefaultParameters.defaultParams(); + params.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); assertNotNull(params); assertInstanceOf(WSDParameters.class, params); assertTrue(params.areValid()); - assertEquals(trainingDir, params.getTrainingDataDirectory()); - assertEquals(WSDDefaultParameters.DFLT_NGRAM, params.getNgram()); - assertEquals(WSDDefaultParameters.DFLT_WIN_SIZE, params.getWindowSize()); - assertEquals(WSDDefaultParameters.DFLT_LANG_CODE, params.getLanguageCode()); - assertEquals(WSDDefaultParameters.DFLT_SOURCE, params.getSenseSource()); + assertEquals(WSDDefaultParameters.NGRAM_DEFAULT, + params.getIntParameter(WSDDefaultParameters.NGRAM_PARAM, WSDDefaultParameters.NGRAM_DEFAULT)); + assertEquals(WSDDefaultParameters.WINDOW_SIZE_DEFAULT, + params.getIntParameter(WSDDefaultParameters.WINDOW_SIZE_PARAM, WSDDefaultParameters.WINDOW_SIZE_DEFAULT)); + assertEquals(WSDDefaultParameters.LANG_CODE_DEFAULT, + params.getStringParameter(WSDDefaultParameters.LANG_CODE, WSDDefaultParameters.LANG_CODE_DEFAULT)); + assertEquals(WSDDefaultParameters.SOURCE_DEFAULT, + WSDParameters.SenseSource.valueOf( + params.getStringParameter(WSDDefaultParameters.SENSE_SOURCE_PARAM, WSDDefaultParameters.SOURCE_DEFAULT.name()))); } } diff --git a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDEvaluatorIT.java b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDEvaluatorIT.java index e607325c..17085589 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDEvaluatorIT.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDEvaluatorIT.java @@ -39,6 +39,7 @@ import opennlp.tools.util.TrainingParameters; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; class WSDEvaluatorIT extends AbstractEvaluatorTest { @@ -55,9 +56,15 @@ static void prepareEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { Path workDir = tmpDir.resolve("models" + File.separatorChar); Path trainingDir = workDir.resolve("training" + File.separatorChar) .resolve("supervised" + File.separatorChar); - wsdParams = new WSDDefaultParameters(trainingDir); + File folder = trainingDir.toFile(); + if (!folder.exists()) { + assertTrue(folder.mkdirs()); + } + wsdParams = WSDDefaultParameters.defaultParams(); + wsdParams.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); final SemcorReaderExtended seReader = new SemcorReaderExtended(SEMCOR_DIR); + final WSDisambiguatorFactory factory = new WSDisambiguatorFactory(); // train the models in parallel sampleTestWordMapping.keySet().parallelStream().forEach(word -> { @@ -71,12 +78,12 @@ static void prepareEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { * file trained by semcor */ try { - final TrainingParameters trainingParams = TrainingParameters.defaultParams(); - trainingParams.put(TrainingParameters.THREADS_PARAM, 4); - WSDModel trained = WSDisambiguatorME.train("en", sampleStream, trainingParams, wsdParams); + final TrainingParameters params = TrainingParameters.defaultParams(); + params.put(TrainingParameters.THREADS_PARAM, 4); + WSDModel trained = WSDisambiguatorME.train("en", sampleStream, params, wsdParams, factory); assertNotNull(trained, "Checking the model to be written"); - File modelFile = new File(wsdParams.getTrainingDataDirectory() + - Character.toString(File.separatorChar) + word + ".wsd.model"); + File modelFile = new File(wsdParams.getStringParameter( + WSDDefaultParameters.TRAINING_DIR_PARAM, "") + File.separatorChar + word + ".wsd.model"); try (OutputStream modelOut = new BufferedOutputStream(new FileOutputStream(modelFile))) { trained.serialize(modelOut); } @@ -94,8 +101,9 @@ void testDisambiguationEval() { sampleTestWordMapping.keySet().parallelStream().forEach(word -> { // don't take verbs because they are not from WordNet if (!SPLIT.split(word)[1].equals("v")) { - File modelFile = new File(wsdParams.getTrainingDataDirectory() + - Character.toString(File.separatorChar) + word + ".wsd.model"); + File modelFile = new File(wsdParams.getStringParameter( + WSDDefaultParameters.TRAINING_DIR_PARAM, "") + + File.separatorChar + word + ".wsd.model"); WSDModel model = null; try { model = new WSDModel(modelFile); diff --git a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDModelTest.java b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDModelTest.java index 6e25ed1a..3c382b36 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDModelTest.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDModelTest.java @@ -28,15 +28,15 @@ import java.io.OutputStream; import java.nio.file.Path; -import opennlp.tools.AbstractTest; -import opennlp.tools.ml.model.MaxentModel; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.CleanupMode; import org.junit.jupiter.api.io.TempDir; +import opennlp.tools.AbstractTest; import opennlp.tools.disambiguator.datareader.SemcorReaderExtended; import opennlp.tools.ml.maxent.GISModel; +import opennlp.tools.ml.model.MaxentModel; import opennlp.tools.util.ObjectStream; import opennlp.tools.util.TrainingParameters; @@ -61,20 +61,25 @@ class WSDModelTest extends AbstractTest { @BeforeAll static void createSimpleWSDModel(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { - Path workDir = tmpDir.resolve("models" + File.separatorChar); trainingDir = workDir.resolve("training" + File.separatorChar) .resolve("supervised" + File.separatorChar); + File folder = trainingDir.toFile(); + if (!folder.exists()) { + assertTrue(folder.mkdirs()); + } + final TrainingParameters params = TrainingParameters.defaultParams(); params.put(TrainingParameters.THREADS_PARAM, 4); + final WSDDefaultParameters wsdParams = WSDDefaultParameters.defaultParams(); + wsdParams.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); + final WSDisambiguatorFactory factory = new WSDisambiguatorFactory(); final SemcorReaderExtended sr = new SemcorReaderExtended(SEMCOR_DIR); - final ObjectStream samples = sr.getSemcorDataStream(WORD_TAG); try { - trainedModel = WSDisambiguatorME.train("en", samples, params, - new WSDDefaultParameters(trainingDir)); + trainedModel = WSDisambiguatorME.train("en", samples, params, wsdParams, factory); assertNotNull(trainedModel); } catch (IOException e1) { fail("Exception in training: " + e1.getMessage()); diff --git a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDisambiguatorMETest.java b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDisambiguatorMETest.java index 09b84c79..877f48a1 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDisambiguatorMETest.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDisambiguatorMETest.java @@ -37,6 +37,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; /** @@ -66,8 +67,15 @@ static void prepareEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { Path workDir = tmpDir.resolve("models" + File.separatorChar); Path trainingDir = workDir.resolve("training" + File.separatorChar) .resolve("supervised" + File.separatorChar); - wsdParams = new WSDDefaultParameters(trainingDir); + File folder = trainingDir.toFile(); + if (!folder.exists()) { + assertTrue(folder.mkdirs()); + } + wsdParams = WSDDefaultParameters.defaultParams(); + wsdParams.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); + final TrainingParameters trainingParams = TrainingParameters.defaultParams(); + final WSDisambiguatorFactory factory = new WSDisambiguatorFactory(); final SemcorReaderExtended sr = new SemcorReaderExtended(SEMCOR_DIR); final String test = "please.v"; @@ -78,7 +86,7 @@ static void prepareEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { * We test both writing and reading a model file trained by semcor */ try { - model= WSDisambiguatorME.train("en", sampleStream, trainingParams, wsdParams); + model= WSDisambiguatorME.train("en", sampleStream, trainingParams, wsdParams, factory); assertNotNull(model, "Checking the model"); } catch (IOException e1) { fail("Exception in training: " + e1.getMessage());
+ * Note: + * {@link java.lang.ClassCastException} can be thrown if the value is not {@code String} + * + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * @return The {@link String training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public String getStringParameter(String key, String defaultValue) { + return getStringParameter(null, key, defaultValue); + } + + /** + * Obtains a training parameter value in the specified namespace. + *
+ * Note: + * {@link java.lang.ClassCastException} can be thrown if the value is not {@link String} + * @param namespace A prefix to declare or use a name space under which {@code key} shall be searched. + * May be {@code null}. + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * + * @return The {@link String training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public String getStringParameter(String namespace, String key, String defaultValue) { + Object value = parameters.get(getKey(namespace, key)); + if (value == null) { + return defaultValue; + } + else { + return (String)value; + } + } + + /** + * Obtains a training parameter value. + *
+ * + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * @return The {@link Integer training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public int getIntParameter(String key, int defaultValue) { + return getIntParameter(null, key, defaultValue); + } + + /** + * Obtains a training parameter value in the specified namespace. + *
+ * @param namespace A prefix to declare or use a name space under which {@code key} shall be searched. + * May be {@code null}. + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * + * @return The {@link Integer training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public int getIntParameter(String namespace, String key, int defaultValue) { + Object value = parameters.get(getKey(namespace, key)); + if (value == null) { + return defaultValue; + } + else { + try { + return (Integer) value; + } + catch (ClassCastException e) { + return Integer.parseInt((String)value); + } + } + } + + /** + * Obtains a training parameter value. + *
+ * + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * @return The {@link Double training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public double getDoubleParameter(String key, double defaultValue) { + return getDoubleParameter(null, key, defaultValue); + } + + /** + * Obtains a training parameter value in the specified namespace. + *
+ * @param namespace A prefix to declare or use a name space under which {@code key} shall be searched. + * May be {@code null}. + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * + * @return The {@link Double training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public double getDoubleParameter(String namespace, String key, double defaultValue) { + Object value = parameters.get(getKey(namespace, key)); + if (value == null) { + return defaultValue; + } + else { + try { + return (Double) value; + } + catch (ClassCastException e) { + return Double.parseDouble((String)value); + } + } + } + + /** + * Obtains a training parameter value. + *
+ * + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * @return The {@link Boolean training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public boolean getBooleanParameter(String key, boolean defaultValue) { + return getBooleanParameter(null, key, defaultValue); + } + + /** + * Obtains a training parameter value in the specified namespace. + *
+ * @param namespace A prefix to declare or use a name space under which {@code key} shall be searched. + * May be {@code null}. + * @param key The identifying key to retrieve a {@code value} with. + * @param defaultValue The alternative value to use, if {@code key} was not present. + * + * @return The {@link Boolean training value} associated with {@code key} if present, + * or a {@code defaultValue} if not. + */ + public boolean getBooleanParameter(String namespace, String key, boolean defaultValue) { + Object value = parameters.get(getKey(namespace, key)); + if (value == null) { + return defaultValue; + } + else { + try { + return (Boolean) value; + } + catch (ClassCastException e) { + return Boolean.parseBoolean((String)value); + } + } + } + + /** + * @return Retrieves a new {@link WSDDefaultParameters instance} initialized with default values. + */ + public static WSDDefaultParameters defaultParams() { + WSDDefaultParameters wsdParams = new WSDDefaultParameters(); + wsdParams.put(WSDDefaultParameters.LANG_CODE, LANG_CODE_DEFAULT); + wsdParams.put(WSDDefaultParameters.WINDOW_SIZE_PARAM, WINDOW_SIZE_DEFAULT); + wsdParams.put(WSDDefaultParameters.NGRAM_PARAM, NGRAM_DEFAULT); + wsdParams.put(WSDDefaultParameters.SENSE_SOURCE_PARAM, SOURCE_DEFAULT.name()); + return wsdParams; + } + + /** + * @param params The parameters to additionally apply into the new {@link WSDDefaultParameters instance}. + * + * @return Retrieves a new {@link WSDDefaultParameters instance} initialized with given parameter values. + */ + public static WSDDefaultParameters setParams(String[] params) { + WSDDefaultParameters wsdParams = new WSDDefaultParameters(); + wsdParams.put(WSDDefaultParameters.LANG_CODE, LANG_CODE_DEFAULT); + wsdParams.put(WSDDefaultParameters.SENSE_SOURCE_PARAM, SOURCE_DEFAULT.name()); + wsdParams.put(WSDDefaultParameters.WINDOW_SIZE_PARAM , + null != CmdLineUtil.getIntParameter("-" + + WSDDefaultParameters.WINDOW_SIZE_PARAM.toLowerCase() , params) ? + CmdLineUtil.getIntParameter("-" + WSDDefaultParameters.WINDOW_SIZE_PARAM.toLowerCase() , params) : + WINDOW_SIZE_DEFAULT); + wsdParams.put(WSDDefaultParameters.NGRAM_PARAM , + null != CmdLineUtil.getIntParameter("-" + + WSDDefaultParameters.NGRAM_PARAM.toLowerCase() , params) ? + CmdLineUtil.getIntParameter("-" + WSDDefaultParameters.NGRAM_PARAM.toLowerCase() , params) : + NGRAM_DEFAULT); + + return wsdParams; + } + + /** + * @param namespace The namespace used as prefix or {@code null}. + * If {@code null} the {@code key} is left unchanged. + * @param key The identifying key to process. + * + * @return Retrieves a prefixed key in the specified {@code namespace}. + * If no {@code namespace} was specified the returned String is equal to {@code key}. + */ + static String getKey(String namespace, String key) { + if (namespace == null) { + return key; + } + else { + return namespace + "." + key; + } + } } diff --git a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDModel.java b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDModel.java index 87bc3d2c..299d6603 100644 --- a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDModel.java +++ b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDModel.java @@ -45,7 +45,7 @@ public class WSDModel extends BaseModel { @Serial - private static final long serialVersionUID = 2961852011373749729L; + private static final long serialVersionUID = -5191919760365361954L; private static final String COMPONENT_NAME = "WSD"; private static final String WSD_MODEL_ENTRY = "WSD.model"; @@ -159,6 +159,13 @@ public String getWordTag() { return getManifestProperty(WORDTAG); } + /** + * @return Retrieves the active {@link WSDisambiguatorFactory}. + */ + public WSDisambiguatorFactory getWSDFactory() { + return (WSDisambiguatorFactory) this.toolFactory; + } + /** * {@inheritDoc} */ diff --git a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDisambiguatorME.java b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDisambiguatorME.java index 387699f7..024c4620 100644 --- a/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDisambiguatorME.java +++ b/opennlp-wsd/src/main/java/opennlp/tools/disambiguator/WSDisambiguatorME.java @@ -17,6 +17,13 @@ package opennlp.tools.disambiguator; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import opennlp.tools.ml.EventTrainer; import opennlp.tools.ml.TrainerFactory; import opennlp.tools.ml.model.Event; @@ -25,14 +32,6 @@ import opennlp.tools.util.ObjectStreamUtils; import opennlp.tools.util.TrainingParameters; -import java.io.File; -import java.io.IOException; -import java.security.InvalidParameterException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; - /** * A {@link Disambiguator} implementation based on a Maximum Entropy (ME) approach. *
@@ -45,8 +44,7 @@ */ public class WSDisambiguatorME extends AbstractWSDisambiguator { - protected static final WSDContextGenerator CONTEXT_GENERATOR = new IMSWSDContextGenerator(); - + private final WSDContextGenerator cg; private final WSDModel model; /** @@ -60,8 +58,10 @@ public WSDisambiguatorME(WSDModel model, WSDParameters params) { if (model == null || params == null) { throw new IllegalArgumentException("Parameters cannot be null!"); } - this.model = model; super.params = params; + this.model = model; + WSDisambiguatorFactory factory = model.getWSDFactory(); + cg = factory.getContextGenerator(); } /** @@ -84,47 +84,45 @@ public WSDModel getModel() { * during training. Or if reading from the {@link ObjectStream} fails. */ public static WSDModel train(String lang, ObjectStream samples, - TrainingParameters mlParams, WSDParameters params) throws IOException { - - WSDDefaultParameters defParams = ((WSDDefaultParameters) params); - List surroundingContext = buildSurroundingContext(samples, defParams.getWindowSize()); - - HashMap manifestInfoEntries = new HashMap<>(); + TrainingParameters mlParams, WSDParameters params, + WSDisambiguatorFactory factory) throws IOException { - MaxentModel meModel; + final WSDDefaultParameters defParams = ((WSDDefaultParameters) params); + final int wSize = defParams.getIntParameter( + WSDDefaultParameters.WINDOW_SIZE_PARAM, WSDDefaultParameters.WINDOW_SIZE_DEFAULT); + final int ngram = defParams.getIntParameter( + WSDDefaultParameters.NGRAM_PARAM, WSDDefaultParameters.NGRAM_DEFAULT); + List surroundingContext = buildSurroundingContext(samples, wSize); List events = new ArrayList<>(); - ObjectStream es; - WSDSample sample = samples.read(); String wordTag = ""; if (sample != null) { + final WSDContextGenerator cg = factory.getContextGenerator(); wordTag = sample.getTargetWordTag(); do { String sense = sample.getSenseIDs()[0]; - String[] context = CONTEXT_GENERATOR.getContext(sample, - defParams.ngram, defParams.windowSize, surroundingContext); + String[] context = cg.getContext(sample, ngram, wSize, surroundingContext); Event ev = new Event(sense, context); events.add(ev); } while ((sample = samples.read()) != null); } - es = ObjectStreamUtils.createObjectStream(events); + final Map manifestInfoEntries = new HashMap<>(); + ObjectStream es = ObjectStreamUtils.createObjectStream(events); EventTrainer trainer = TrainerFactory.getEventTrainer(mlParams, manifestInfoEntries); + MaxentModel meModel = trainer.train(es); - meModel = trainer.train(es); - - return new WSDModel(lang, wordTag, defParams.windowSize, defParams.ngram, - meModel, surroundingContext, manifestInfoEntries); + return new WSDModel(lang, wordTag, wSize, ngram, meModel, surroundingContext, manifestInfoEntries); } private static List buildSurroundingContext(ObjectStream samples, int windowSize) throws IOException { - IMSWSDContextGenerator contextGenerator = new IMSWSDContextGenerator(); + IMSWSDContextGenerator cg = new IMSWSDContextGenerator(); List surroundingWordsModel = new ArrayList<>(); WSDSample sample; while ((sample = samples.read()) != null) { - String[] words = contextGenerator.extractSurroundingContext(sample.getTargetPosition(), + String[] words = cg.extractSurroundingContext(sample.getTargetPosition(), sample.getSentence(), sample.getLemmas(), windowSize); if (words.length > 0) { @@ -150,14 +148,17 @@ public String disambiguate(String[] tokenizedContext, String[] tokenTags, @Override public String disambiguate(WSDSample sample) { final WSDDefaultParameters defParams = ((WSDDefaultParameters) params); - final String wordTag = sample.getTargetWordTag(); + final int wSize = defParams.getIntParameter( + WSDDefaultParameters.WINDOW_SIZE_PARAM, WSDDefaultParameters.WINDOW_SIZE_DEFAULT); + final int ngram = defParams.getIntParameter( + WSDDefaultParameters.NGRAM_PARAM, WSDDefaultParameters.NGRAM_DEFAULT); + final String wordTag = sample.getTargetWordTag(); if (WSDHelper.isRelevantPOSTag(sample.getTargetTag())) { if (!model.getWordTag().equals(wordTag)) { return disambiguate(wordTag); } else { - String[] context = CONTEXT_GENERATOR.getContext(sample, - defParams.ngram, defParams.windowSize, this.model.getContextEntries()); + String[] context = cg.getContext(sample, wSize, ngram, this.model.getContextEntries()); double[] outcomeProbs = model.getWSDMaxentModel().eval(context); String outcome = model.getWSDMaxentModel().getBestOutcome(outcomeProbs); diff --git a/opennlp-wsd/src/test/java/opennlp/tools/cmdline/wsd/WSDModelLoaderTest.java b/opennlp-wsd/src/test/java/opennlp/tools/cmdline/wsd/WSDModelLoaderTest.java index d76a71ae..b7730204 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/cmdline/wsd/WSDModelLoaderTest.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/cmdline/wsd/WSDModelLoaderTest.java @@ -36,6 +36,7 @@ import opennlp.tools.AbstractTest; import opennlp.tools.disambiguator.WSDDefaultParameters; +import opennlp.tools.disambiguator.WSDisambiguatorFactory; import opennlp.tools.disambiguator.WSDModel; import opennlp.tools.disambiguator.WSDSample; import opennlp.tools.disambiguator.WSDisambiguatorME; @@ -64,18 +65,25 @@ static void createSimpleWSDModel(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmp Path workDir = tmpDir.resolve("models" + File.separatorChar); trainingDir = workDir.resolve("training" + File.separatorChar) .resolve("supervised" + File.separatorChar); + File folder = trainingDir.toFile(); + if (!folder.exists()) { + assertTrue(folder.mkdirs()); + } + final TrainingParameters params = TrainingParameters.defaultParams(); params.put(TrainingParameters.THREADS_PARAM, 4); - + final WSDDefaultParameters wsdParams = WSDDefaultParameters.defaultParams(); + wsdParams.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); + + final WSDisambiguatorFactory factory = new WSDisambiguatorFactory(); final SemcorReaderExtended sr = new SemcorReaderExtended(SEMCOR_DIR); final ObjectStream samples = sr.getSemcorDataStream(WORD_TAG); try { - WSDDefaultParameters wsdParams = new WSDDefaultParameters(trainingDir); - trainedModel = WSDisambiguatorME.train("en", samples, params, wsdParams); + trainedModel = WSDisambiguatorME.train("en", samples, params, wsdParams, factory); assertNotNull(trainedModel); - File modelFile = new File(wsdParams.getTrainingDataDirectory() + - Character.toString(File.separatorChar) + WORD_TAG + ".wsd.model"); + File modelFile = new File(wsdParams.getStringParameter( + WSDDefaultParameters.TRAINING_DIR_PARAM, "") + File.separatorChar + WORD_TAG + ".wsd.model"); try (OutputStream modelOut = new BufferedOutputStream(new FileOutputStream(modelFile))) { trainedModel.serialize(modelOut); } diff --git a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDDefaultParametersTest.java b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDDefaultParametersTest.java index 691d3b33..8a019250 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDDefaultParametersTest.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDDefaultParametersTest.java @@ -45,15 +45,20 @@ static void initEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { @Test void testCreate() { - WSDDefaultParameters params = new WSDDefaultParameters(trainingDir); + WSDDefaultParameters params = WSDDefaultParameters.defaultParams(); + params.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); assertNotNull(params); assertInstanceOf(WSDParameters.class, params); assertTrue(params.areValid()); - assertEquals(trainingDir, params.getTrainingDataDirectory()); - assertEquals(WSDDefaultParameters.DFLT_NGRAM, params.getNgram()); - assertEquals(WSDDefaultParameters.DFLT_WIN_SIZE, params.getWindowSize()); - assertEquals(WSDDefaultParameters.DFLT_LANG_CODE, params.getLanguageCode()); - assertEquals(WSDDefaultParameters.DFLT_SOURCE, params.getSenseSource()); + assertEquals(WSDDefaultParameters.NGRAM_DEFAULT, + params.getIntParameter(WSDDefaultParameters.NGRAM_PARAM, WSDDefaultParameters.NGRAM_DEFAULT)); + assertEquals(WSDDefaultParameters.WINDOW_SIZE_DEFAULT, + params.getIntParameter(WSDDefaultParameters.WINDOW_SIZE_PARAM, WSDDefaultParameters.WINDOW_SIZE_DEFAULT)); + assertEquals(WSDDefaultParameters.LANG_CODE_DEFAULT, + params.getStringParameter(WSDDefaultParameters.LANG_CODE, WSDDefaultParameters.LANG_CODE_DEFAULT)); + assertEquals(WSDDefaultParameters.SOURCE_DEFAULT, + WSDParameters.SenseSource.valueOf( + params.getStringParameter(WSDDefaultParameters.SENSE_SOURCE_PARAM, WSDDefaultParameters.SOURCE_DEFAULT.name()))); } } diff --git a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDEvaluatorIT.java b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDEvaluatorIT.java index e607325c..17085589 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDEvaluatorIT.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDEvaluatorIT.java @@ -39,6 +39,7 @@ import opennlp.tools.util.TrainingParameters; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; class WSDEvaluatorIT extends AbstractEvaluatorTest { @@ -55,9 +56,15 @@ static void prepareEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { Path workDir = tmpDir.resolve("models" + File.separatorChar); Path trainingDir = workDir.resolve("training" + File.separatorChar) .resolve("supervised" + File.separatorChar); - wsdParams = new WSDDefaultParameters(trainingDir); + File folder = trainingDir.toFile(); + if (!folder.exists()) { + assertTrue(folder.mkdirs()); + } + wsdParams = WSDDefaultParameters.defaultParams(); + wsdParams.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); final SemcorReaderExtended seReader = new SemcorReaderExtended(SEMCOR_DIR); + final WSDisambiguatorFactory factory = new WSDisambiguatorFactory(); // train the models in parallel sampleTestWordMapping.keySet().parallelStream().forEach(word -> { @@ -71,12 +78,12 @@ static void prepareEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { * file trained by semcor */ try { - final TrainingParameters trainingParams = TrainingParameters.defaultParams(); - trainingParams.put(TrainingParameters.THREADS_PARAM, 4); - WSDModel trained = WSDisambiguatorME.train("en", sampleStream, trainingParams, wsdParams); + final TrainingParameters params = TrainingParameters.defaultParams(); + params.put(TrainingParameters.THREADS_PARAM, 4); + WSDModel trained = WSDisambiguatorME.train("en", sampleStream, params, wsdParams, factory); assertNotNull(trained, "Checking the model to be written"); - File modelFile = new File(wsdParams.getTrainingDataDirectory() + - Character.toString(File.separatorChar) + word + ".wsd.model"); + File modelFile = new File(wsdParams.getStringParameter( + WSDDefaultParameters.TRAINING_DIR_PARAM, "") + File.separatorChar + word + ".wsd.model"); try (OutputStream modelOut = new BufferedOutputStream(new FileOutputStream(modelFile))) { trained.serialize(modelOut); } @@ -94,8 +101,9 @@ void testDisambiguationEval() { sampleTestWordMapping.keySet().parallelStream().forEach(word -> { // don't take verbs because they are not from WordNet if (!SPLIT.split(word)[1].equals("v")) { - File modelFile = new File(wsdParams.getTrainingDataDirectory() + - Character.toString(File.separatorChar) + word + ".wsd.model"); + File modelFile = new File(wsdParams.getStringParameter( + WSDDefaultParameters.TRAINING_DIR_PARAM, "") + + File.separatorChar + word + ".wsd.model"); WSDModel model = null; try { model = new WSDModel(modelFile); diff --git a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDModelTest.java b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDModelTest.java index 6e25ed1a..3c382b36 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDModelTest.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDModelTest.java @@ -28,15 +28,15 @@ import java.io.OutputStream; import java.nio.file.Path; -import opennlp.tools.AbstractTest; -import opennlp.tools.ml.model.MaxentModel; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.CleanupMode; import org.junit.jupiter.api.io.TempDir; +import opennlp.tools.AbstractTest; import opennlp.tools.disambiguator.datareader.SemcorReaderExtended; import opennlp.tools.ml.maxent.GISModel; +import opennlp.tools.ml.model.MaxentModel; import opennlp.tools.util.ObjectStream; import opennlp.tools.util.TrainingParameters; @@ -61,20 +61,25 @@ class WSDModelTest extends AbstractTest { @BeforeAll static void createSimpleWSDModel(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { - Path workDir = tmpDir.resolve("models" + File.separatorChar); trainingDir = workDir.resolve("training" + File.separatorChar) .resolve("supervised" + File.separatorChar); + File folder = trainingDir.toFile(); + if (!folder.exists()) { + assertTrue(folder.mkdirs()); + } + final TrainingParameters params = TrainingParameters.defaultParams(); params.put(TrainingParameters.THREADS_PARAM, 4); + final WSDDefaultParameters wsdParams = WSDDefaultParameters.defaultParams(); + wsdParams.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); + final WSDisambiguatorFactory factory = new WSDisambiguatorFactory(); final SemcorReaderExtended sr = new SemcorReaderExtended(SEMCOR_DIR); - final ObjectStream samples = sr.getSemcorDataStream(WORD_TAG); try { - trainedModel = WSDisambiguatorME.train("en", samples, params, - new WSDDefaultParameters(trainingDir)); + trainedModel = WSDisambiguatorME.train("en", samples, params, wsdParams, factory); assertNotNull(trainedModel); } catch (IOException e1) { fail("Exception in training: " + e1.getMessage()); diff --git a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDisambiguatorMETest.java b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDisambiguatorMETest.java index 09b84c79..877f48a1 100644 --- a/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDisambiguatorMETest.java +++ b/opennlp-wsd/src/test/java/opennlp/tools/disambiguator/WSDisambiguatorMETest.java @@ -37,6 +37,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; /** @@ -66,8 +67,15 @@ static void prepareEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { Path workDir = tmpDir.resolve("models" + File.separatorChar); Path trainingDir = workDir.resolve("training" + File.separatorChar) .resolve("supervised" + File.separatorChar); - wsdParams = new WSDDefaultParameters(trainingDir); + File folder = trainingDir.toFile(); + if (!folder.exists()) { + assertTrue(folder.mkdirs()); + } + wsdParams = WSDDefaultParameters.defaultParams(); + wsdParams.putIfAbsent(WSDDefaultParameters.TRAINING_DIR_PARAM, trainingDir.toAbsolutePath().toString()); + final TrainingParameters trainingParams = TrainingParameters.defaultParams(); + final WSDisambiguatorFactory factory = new WSDisambiguatorFactory(); final SemcorReaderExtended sr = new SemcorReaderExtended(SEMCOR_DIR); final String test = "please.v"; @@ -78,7 +86,7 @@ static void prepareEnv(@TempDir(cleanup = CleanupMode.ALWAYS) Path tmpDir) { * We test both writing and reading a model file trained by semcor */ try { - model= WSDisambiguatorME.train("en", sampleStream, trainingParams, wsdParams); + model= WSDisambiguatorME.train("en", sampleStream, trainingParams, wsdParams, factory); assertNotNull(model, "Checking the model"); } catch (IOException e1) { fail("Exception in training: " + e1.getMessage());