@PublicationReference(author="Jaakkola", title="Estimating mixtures: the EM-algorithm", type=Misc, year=2007, url="http://courses.csail.mit.edu/6.867/lectures/notes-em2.pdf") public static class MixtureOfGaussians.EMLearner extends AbstractAnytimeBatchLearner<java.util.Collection<? extends Vector>,MixtureOfGaussians.PDF> implements Randomized, DistributionEstimator<Vector,MixtureOfGaussians.PDF>, MeasurablePerformanceAlgorithm
Modifier and Type | Field and Description |
---|---|
static int |
DEFAULT_MAX_ITERATIONS
Default max iterations, 100.
|
static double |
DEFAULT_TOLERANCE
Default tolerance, 1.0E-5.
|
static java.lang.String |
PERFORMANCE_NAME
Name of the performance measurement, "Assignment Change".
|
protected java.util.Random |
random
Random number generator.
|
data, keepGoing
maxIterations
DEFAULT_ITERATION, iteration
Constructor and Description |
---|
EMLearner(int distributionCount,
MultivariateGaussian.WeightedMaximumLikelihoodEstimator learner,
java.util.Random random)
Creates a new instance of EMLearner
|
EMLearner(int distributionCount,
java.util.Random random)
Creates a new instance of EMLearner
|
EMLearner(java.util.Random random)
Creates a new instance of EMLearner
|
Modifier and Type | Method and Description |
---|---|
protected void |
cleanupAlgorithm()
Called to clean up the learning algorithm's state after learning has
finished.
|
double |
getAssignmentChanged()
Gets the total assignment change from the last completed step of
the algorithm.
|
NamedValue<java.lang.Double> |
getPerformance()
Gets the name-value pair that describes the current performance of the
algorithm.
|
java.util.Random |
getRandom()
Gets the random number generator used by this object.
|
MixtureOfGaussians.PDF |
getResult()
Gets the current result of the algorithm.
|
double |
getTolerance()
Gets the tolerance.
|
protected boolean |
initializeAlgorithm()
Called to initialize the learning algorithm's state based on the
data that is stored in the data field.
|
void |
setRandom(java.util.Random random)
Sets the random number generator used by this object.
|
void |
setTolerance(double tolerance)
Sets the tolerance.
|
protected boolean |
step()
Called to take a single step of the learning algorithm.
|
clone, getData, getKeepGoing, learn, setData, setKeepGoing, stop
getMaxIterations, isResultValid, setMaxIterations
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
learn
clone
getMaxIterations, setMaxIterations
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
isResultValid
public static final java.lang.String PERFORMANCE_NAME
public static final int DEFAULT_MAX_ITERATIONS
public static final double DEFAULT_TOLERANCE
protected java.util.Random random
public EMLearner(java.util.Random random)
random
- Random number generator
public EMLearner(int distributionCount, java.util.Random random)
distributionCount
- Number of distributions in the mixture
random
- Random number generator
public EMLearner(int distributionCount, MultivariateGaussian.WeightedMaximumLikelihoodEstimator learner, java.util.Random random)
distributionCount
- Number of distributions in the mixture
learner
- Learner used to re-estimate the components
random
- Random number generator
protected boolean initializeAlgorithm()
AbstractAnytimeBatchLearner
initializeAlgorithm
in class AbstractAnytimeBatchLearner<java.util.Collection<? extends Vector>,MixtureOfGaussians.PDF>
protected boolean step()
AbstractAnytimeBatchLearner
step
in class AbstractAnytimeBatchLearner<java.util.Collection<? extends Vector>,MixtureOfGaussians.PDF>
protected void cleanupAlgorithm()
AbstractAnytimeBatchLearner
cleanupAlgorithm
in class AbstractAnytimeBatchLearner<java.util.Collection<? extends Vector>,MixtureOfGaussians.PDF>
public MixtureOfGaussians.PDF getResult()
AnytimeAlgorithm
getResult
in interface AnytimeAlgorithm<MixtureOfGaussians.PDF>
public NamedValue<java.lang.Double> getPerformance()
MeasurablePerformanceAlgorithm
getPerformance
in interface MeasurablePerformanceAlgorithm
public double getTolerance()
public void setTolerance(double tolerance)
tolerance
- Tolerance before stopping; must be greater than or equal to 0
public java.util.Random getRandom()
Randomized
getRandom
in interface Randomized
public void setRandom(java.util.Random random)
Randomized
setRandom
in interface Randomized
random
- The random number generator for this object to use.
public double getAssignmentChanged()