From 20f9c1d01fa61be6e349f563532ea7c30d578535 Mon Sep 17 00:00:00 2001
From: Steve Ash
Date: Wed, 18 Feb 2015 21:43:51 -0600
Subject: [PATCH] adding ability to initialize starting values for parameters from another model taking whatever matches

---
 src/cc/mallet/fst/CRF.java | 71 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)

diff --git a/src/cc/mallet/fst/CRF.java b/src/cc/mallet/fst/CRF.java
index e114050ed..a9de83833 100644
--- a/src/cc/mallet/fst/CRF.java
+++ b/src/cc/mallet/fst/CRF.java
@@ -1558,6 +1558,77 @@ public void induceFeaturesFor (InstanceList instances) {
 		}
 	}

+	/**
+	 * Seeds this CRF's parameters from another CRF wherever the two models overlap.
+	 * <p>
+	 * Matching is done by name, never by position: states are matched by state name,
+	 * weight sets by their key in the weight alphabet, and individual features by
+	 * their entry in the input alphabet. Anything without a counterpart in
+	 * {@code startingPoint} keeps its current value, so the two models may differ in
+	 * structure and feature set.
+	 * <p>
+	 * Only feature locations that already exist in this CRF's weight vectors are
+	 * updated, so call this <em>after</em> the states and weight dimensions of this
+	 * CRF have been set up. A feature whose value in {@code startingPoint} is zero is
+	 * left untouched and is not counted as matched. The number of matched states,
+	 * weight sets and features is logged when the copy finishes.
+	 *
+	 * @param startingPoint the model to copy matching values from
+	 * @throws NullPointerException if {@code startingPoint} is null
+	 */
+	public void initializeApplicableParametersFrom(CRF startingPoint) {
+		if (startingPoint == null) {
+			throw new NullPointerException("startingPoint must not be null");
+		}
+
+		int stateCount = 0;
+		int transitionCount = 0;
+		int featureCount = 0;
+
+		// Initial/final weights: copy for every state that exists under the same name.
+		for (int i = 0; i < this.states.size(); i++) {
+			State thisState = this.states.get(i);
+			State thatState = startingPoint.getState(thisState.getName());
+			if (thatState == null) continue;
+
+			parameters.initialWeights[thisState.index] =
+					startingPoint.parameters.initialWeights[thatState.index];
+			parameters.finalWeights[thisState.index] =
+					startingPoint.parameters.finalWeights[thatState.index];
+			stateCount += 1;
+		}
+
+		// Weight sets: the same key can sit at a different index in each model, so
+		// translate through the weight alphabet rather than assuming equal positions.
+		for (int i = 0; i < parameters.weightAlphabet.size(); i++) {
+			Object weightKey = parameters.weightAlphabet.lookupObject(i);
+			int spIndex = startingPoint.parameters.weightAlphabet.lookupIndex(weightKey, false);
+			if (spIndex < 0) continue;
+
+			transitionCount += 1;
+			this.parameters.defaultWeights[i] = startingPoint.parameters.defaultWeights[spIndex];
+
+			SparseVector thisFe = this.parameters.weights[i];
+			SparseVector thatFe = startingPoint.parameters.weights[spIndex];
+			for (int j = 0; j < thisFe.numLocations(); j++) {
+				// Feature indices are per-model as well; go through the input alphabets.
+				int thisIndex = thisFe.indexAtLocation(j);
+				Object thisFeature = this.inputAlphabet.lookupObject(thisIndex);
+				int thatIndex = startingPoint.inputAlphabet.lookupIndex(thisFeature, false);
+				if (thatIndex < 0) continue;
+
+				double thatValue = thatFe.value(thatIndex);
+				if (thatValue != 0) {
+					thisFe.setValueAtLocation(j, thatValue);
+					featureCount += 1;
+				}
+			}
+		}
+		weightsValueChanged();
+		logger.info("Finished initializing from previous model: matched " + transitionCount +
+				" transitions, " + stateCount + " states, and " + featureCount + " features");
+	}
+
 	// TODO Put support to Optimizable here, including getValue(InstanceList)??

 	public void print ()