ObservationType
- Type of observations handled by the mixture model
@PublicationReference(author="Radford M. Neal",title="Markov Chain Sampling Methods for Dirichlet Process Mixture Models",type=Journal,year=2000,publication="Journal of Computational and Graphical Statistics, Vol. 9, No. 2",pages={249,265},notes="Based in part on Algorithm 2 from Neal")
@PublicationReference(author={"Michael D. Escobar","Mike West"},title="Bayesian Density Estimation and Inference Using Mixtures",type=Journal,publication="Journal of the American Statistical Association",year=1995)
public class DirichletProcessMixtureModel<ObservationType>
extends AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
Modifier and Type | Class and Description |
---|---|
static class |
DirichletProcessMixtureModel.DPMMCluster<ObservationType>
Cluster for a step in the DPMM
|
protected static class |
DirichletProcessMixtureModel.DPMMLogConditional
Container for the log conditional likelihood
|
static class |
DirichletProcessMixtureModel.MultivariateMeanCovarianceUpdater
Updater that creates specified clusters with distinct means and covariances
|
static class |
DirichletProcessMixtureModel.MultivariateMeanUpdater
Updater that creates specified clusters with identical covariances
|
static class |
DirichletProcessMixtureModel.Sample<ObservationType>
A sample from the Dirichlet Process Mixture Model.
|
static interface |
DirichletProcessMixtureModel.Updater<ObservationType>
Updater for the DPMM
|
Modifier and Type | Field and Description |
---|---|
protected GammaDistribution |
alphaInverseSampler
Samples a new alpha-inverse.
|
protected double[] |
clusterWeights
Holds the cluster weights so that we don't have to re-allocate them
each mcmcUpdate step.
|
protected ProbabilityFunction<ObservationType> |
conditionalPriorPredictive
Base predictive distribution that determines the weight given to a
new cluster during Gibbs sampling.
|
static double |
DEFAULT_ALPHA
Default concentration parameter of the Dirichlet Process, 1.0.
|
static int |
DEFAULT_NUM_INITIAL_CLUSTERS
Default number of initial clusters
|
static boolean |
DEFAULT_REESTIMATE_ALPHA
The default value for re-estimating alpha is true.
|
protected BetaDistribution |
etaSampler
Creates a new value of "eta" which, in turn, helps sample a new alpha.
|
protected double |
initialAlpha
Initial value of alpha, the concentration parameter of the
Dirichlet Process
|
protected boolean |
reestimateAlpha
Flag to automatically re-estimate the alpha parameter
|
protected DirichletProcessMixtureModel.Updater<ObservationType> |
updater
Creates the clusters and predictive prior distributions
|
currentParameter, DEFAULT_NUM_SAMPLES, previousParameter, random
data, keepGoing
maxIterations
DEFAULT_ITERATION, iteration
Constructor and Description |
---|
DirichletProcessMixtureModel()
Creates a new instance of DirichletProcessMixtureModel
|
Modifier and Type | Method and Description |
---|---|
protected java.util.ArrayList<java.util.Collection<ObservationType>> |
assignObservationsToClusters(int K,
DirichletProcessMixtureModel.DPMMLogConditional logConditional)
Assigns observations to each of the K clusters,
plus the as-yet-uncreated new cluster
|
protected int |
assignObservationToCluster(ObservationType observation,
double[] weights,
DirichletProcessMixtureModel.DPMMLogConditional logConditional)
Probabilistically assigns an observation to a cluster
|
DirichletProcessMixtureModel<ObservationType> |
clone()
This makes public the clone method on the Object class and
removes the exception that it throws.
|
protected DirichletProcessMixtureModel.DPMMCluster<ObservationType> |
createCluster(java.util.Collection<ObservationType> clusterAssignment,
DirichletProcessMixtureModel.Updater<ObservationType> localUpdater)
Creates a cluster from the given cluster assignment
|
DirichletProcessMixtureModel.Sample<ObservationType> |
createInitialLearnedObject()
Creates the initial parameters from which to start the Markov chain.
|
double |
getInitialAlpha()
Getter for initialAlpha
|
int |
getNumInitialClusters()
Getter for numInitialClusters
|
boolean |
getReestimateAlpha()
Getter for reestimateAlpha
|
DirichletProcessMixtureModel.Updater<ObservationType> |
getUpdater()
Getter for updater
|
protected void |
mcmcUpdate()
Performs a valid MCMC update step.
|
void |
setInitialAlpha(double initialAlpha)
Setter for initialAlpha
|
void |
setNumInitialClusters(int numInitialClusters)
Setter for numInitialClusters
|
void |
setReestimateAlpha(boolean reestimateAlpha)
Setter for reestimateAlpha
|
void |
setUpdater(DirichletProcessMixtureModel.Updater<ObservationType> updater)
Setter for updater
|
protected double |
updateAlpha(double alpha,
int numObservations)
Runs the Gibbs sampler for the concentration parameter, alpha, given
the data.
|
protected java.util.ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> |
updateClusters(java.util.ArrayList<java.util.Collection<ObservationType>> clusterAssignments)
Update each cluster according to the data assigned to it
|
cleanupAlgorithm, getBurnInIterations, getCurrentParameter, getIterationsPerSample, getPreviousParameter, getRandom, getResult, initializeAlgorithm, setBurnInIterations, setCurrentParameter, setIterationsPerSample, setRandom, setResult, step
getData, getKeepGoing, learn, setData, setKeepGoing, stop
getMaxIterations, isResultValid, setMaxIterations
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
learn
getMaxIterations, setMaxIterations
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
isResultValid, stop
public static final double DEFAULT_ALPHA
public static final int DEFAULT_NUM_INITIAL_CLUSTERS
public static final boolean DEFAULT_REESTIMATE_ALPHA
protected DirichletProcessMixtureModel.Updater<ObservationType> updater
protected boolean reestimateAlpha
protected double initialAlpha
protected transient ProbabilityFunction<ObservationType> conditionalPriorPredictive
protected transient double[] clusterWeights
protected transient BetaDistribution etaSampler
protected transient GammaDistribution alphaInverseSampler
public DirichletProcessMixtureModel()
public DirichletProcessMixtureModel<ObservationType> clone()
AbstractCloneableSerializable
This makes public the clone method on the Object class and
removes the exception that it throws. Its default behavior is to
automatically create a clone of the exact type of object that the
clone is called on and to copy all primitives but to keep all references,
which means it is a shallow copy.
Extensions of this class may want to override this method (but call
super.clone()) to implement a "smart copy", that is, to target
the most common use case for creating a copy of the object. Because
the default behavior is a shallow copy, extending classes only need
to handle fields that need a deeper copy (or those that need to
be reset). Some of the methods in ObjectUtil may be helpful in
implementing a custom clone method.
Note: The contract of this method is that you must use
super.clone() as the basis for your implementation.
clone
in interface CloneableSerializable
clone
in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
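The "smart copy" contract described for clone() can be illustrated with a minimal, library-free sketch. The class SmartCopyExample and its fields are hypothetical and are not part of this API; the example only assumes standard java.lang.Object.clone() behavior: start from super.clone(), then deepen the fields that must not be shared.

```java
import java.util.Arrays;

// Hypothetical class for illustration only; not part of the documented API.
class SmartCopyExample implements Cloneable {
    // Primitive field: Object.clone() copies it by value.
    protected double alpha = 1.0;
    // Array field: Object.clone() copies only the reference (shallow copy).
    protected double[] weights = {0.25, 0.75};

    @Override
    public SmartCopyExample clone() {
        try {
            // Use super.clone() as the basis, as the contract requires.
            SmartCopyExample copy = (SmartCopyExample) super.clone();
            // Deepen only the fields that need their own storage.
            copy.weights = Arrays.copyOf(this.weights, this.weights.length);
            return copy;
        } catch (CloneNotSupportedException e) {
            // Unreachable here because the class implements Cloneable.
            throw new AssertionError(e);
        }
    }

    public static void main(String[] args) {
        SmartCopyExample original = new SmartCopyExample();
        SmartCopyExample copy = original.clone();
        copy.weights[0] = 0.9;
        // The original array is unaffected by changes to the copy.
        System.out.println(original.weights[0]);
    }
}
```

Without the Arrays.copyOf line, both objects would share one weights array, which is the shallow-copy default the description warns about.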
protected void mcmcUpdate()
AbstractMarkovChainMonteCarlo
mcmcUpdate
in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
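The class annotation states that the sampler is based in part on Algorithm 2 from Neal (2000). As a rough illustration of that style of update, and not a copy of this class's code, the sketch below weights each existing cluster by its size times the likelihood of the observation, weights a new cluster by alpha times the prior predictive density, and then draws an index in proportion to those weights. The class name, method names, and parameters are hypothetical, and the counts are assumed to exclude the observation being reassigned.

```java
import java.util.Random;

// Hypothetical helper for illustration only; not part of the documented API.
class ClusterAssignmentSketch {
    /**
     * Unnormalized log weights for K existing clusters plus one new cluster.
     * counts[k] is the number of other observations in cluster k,
     * logLikelihoods[k] is log p(observation | cluster k), and
     * logPriorPredictive is the log base predictive density of the observation.
     */
    static double[] logWeights(int[] counts, double[] logLikelihoods,
                               double logPriorPredictive, double alpha) {
        int numClusters = counts.length;
        double[] logW = new double[numClusters + 1];
        for (int k = 0; k < numClusters; k++) {
            logW[k] = Math.log(counts[k]) + logLikelihoods[k];
        }
        // The final slot is the as-yet-uncreated new cluster.
        logW[numClusters] = Math.log(alpha) + logPriorPredictive;
        return logW;
    }

    /** Draws index i with probability proportional to exp(logW[i]). */
    static int sampleIndex(double[] logW, Random random) {
        double max = Double.NEGATIVE_INFINITY;
        for (double value : logW) {
            max = Math.max(max, value);
        }
        // Subtract the maximum before exponentiating to avoid underflow.
        double[] w = new double[logW.length];
        double total = 0.0;
        for (int i = 0; i < logW.length; i++) {
            w[i] = Math.exp(logW[i] - max);
            total += w[i];
        }
        double u = random.nextDouble() * total;
        double cumulative = 0.0;
        for (int i = 0; i < w.length; i++) {
            cumulative += w[i];
            if (u < cumulative) {
                return i;
            }
        }
        return w.length - 1; // Guard against floating-point rounding.
    }
}
```

Returning the index K (one past the last existing cluster) signals that the observation should start a new cluster. Working in log space keeps the draw stable when likelihoods are very small.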
protected java.util.ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> updateClusters(java.util.ArrayList<java.util.Collection<ObservationType>> clusterAssignments)
clusterAssignments
- Observations assigned to each cluster
protected java.util.ArrayList<java.util.Collection<ObservationType>> assignObservationsToClusters(int K, DirichletProcessMixtureModel.DPMMLogConditional logConditional)
K
- Number of clusters
logConditional
- The log of the conditional.
protected int assignObservationToCluster(ObservationType observation, double[] weights, DirichletProcessMixtureModel.DPMMLogConditional logConditional)
observation
- Observation that we're assigning
weights
- Place holder for the weights that this method will create
logConditional
- The log of the conditional.
protected DirichletProcessMixtureModel.DPMMCluster<ObservationType> createCluster(java.util.Collection<ObservationType> clusterAssignment, DirichletProcessMixtureModel.Updater<ObservationType> localUpdater)
clusterAssignment
- Observations assigned to a particular cluster
localUpdater
- Updater that recomputes the cluster parameters, needed to ensure
thread safety in the parallel implementation
protected double updateAlpha(double alpha, int numObservations)
alpha
- Current value of the concentration parameter
numObservations
- Number of observations we're sampling over
public DirichletProcessMixtureModel.Sample<ObservationType> createInitialLearnedObject()
AbstractMarkovChainMonteCarlo
createInitialLearnedObject
in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
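The updateAlpha method above is described as a Gibbs sampler for the concentration parameter, and the class cites Escobar and West (1995). The sketch below shows the auxiliary-variable update from that paper under a Gamma(priorShape, priorRate) prior on alpha. It is an illustration under stated assumptions, not this class's code: the class name, the method names, the prior hyperparameters, and the explicit numClusters argument are all hypothetical (the documented updateAlpha takes only alpha and numObservations). The Gamma sampler uses the Marsaglia-Tsang method so that the example needs only the standard library.

```java
import java.util.Random;

// Hypothetical helper for illustration only; not part of the documented API.
class AlphaResamplingSketch {
    /** Draws from Gamma(shape, rate = 1) with the Marsaglia-Tsang method. */
    static double sampleGamma(double shape, Random rng) {
        if (shape < 1.0) {
            // Boosting step: Gamma(shape) = Gamma(shape + 1) * U^(1 / shape).
            double u = rng.nextDouble();
            return sampleGamma(shape + 1.0, rng) * Math.pow(u, 1.0 / shape);
        }
        double d = shape - 1.0 / 3.0;
        double c = 1.0 / Math.sqrt(9.0 * d);
        while (true) {
            double x = rng.nextGaussian();
            double v = 1.0 + c * x;
            if (v <= 0.0) {
                continue; // Reject proposals outside the support.
            }
            v = v * v * v;
            double u = rng.nextDouble();
            if (Math.log(u) < 0.5 * x * x + d - d * v + d * Math.log(v)) {
                return d * v;
            }
        }
    }

    /** Draws from Beta(a, b) as a ratio of two Gamma draws. */
    static double sampleBeta(double a, double b, Random rng) {
        double x = sampleGamma(a, rng);
        double y = sampleGamma(b, rng);
        return x / (x + y);
    }

    /** One Escobar-West style draw of alpha given the current state. */
    static double resampleAlpha(double alpha, int numClusters, int numObservations,
                                double priorShape, double priorRate, Random rng) {
        // Auxiliary variable: eta | alpha ~ Beta(alpha + 1, n).
        double eta = sampleBeta(alpha + 1.0, numObservations, rng);
        double rate = priorRate - Math.log(eta);
        // Mixture odds: pi / (1 - pi) = (a + k - 1) / (n * (b - log eta)).
        double odds = (priorShape + numClusters - 1.0) / (numObservations * rate);
        double pi = odds / (1.0 + odds);
        // With probability pi use shape a + k, otherwise shape a + k - 1.
        double shape = priorShape + numClusters - (rng.nextDouble() < pi ? 0.0 : 1.0);
        return sampleGamma(shape, rng) / rate;
    }
}
```

Calling resampleAlpha once per MCMC step, with the current number of clusters, lets alpha adapt to the data instead of staying at its initial value, which is the behavior the reestimateAlpha flag appears to control.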
public DirichletProcessMixtureModel.Updater<ObservationType> getUpdater()
public void setUpdater(DirichletProcessMixtureModel.Updater<ObservationType> updater)
updater
- Creates the clusters and predictive prior distributions
public int getNumInitialClusters()
public void setNumInitialClusters(int numInitialClusters)
numInitialClusters
- Number of clusters to initialize
public boolean getReestimateAlpha()
public void setReestimateAlpha(boolean reestimateAlpha)
reestimateAlpha
- Flag to automatically re-estimate the alpha parameter
public double getInitialAlpha()
public void setInitialAlpha(double initialAlpha)
initialAlpha
- Initial value of alpha, the concentration parameter of the
Dirichlet Process
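When choosing a value for initialAlpha, it can help to see how the concentration parameter relates to the number of clusters. Under a Dirichlet Process prior, the expected number of distinct clusters among n observations is the sum over i = 0..n-1 of alpha / (alpha + i). The sketch below evaluates that sum; the class and method names are hypothetical and not part of this API.

```java
// Hypothetical helper for illustration only; not part of the documented API.
class ExpectedClustersSketch {
    /** Prior expected number of clusters for n observations under DP(alpha). */
    static double expectedClusters(double alpha, int numObservations) {
        double expected = 0.0;
        for (int i = 0; i < numObservations; i++) {
            // Observation i + 1 opens a new cluster with probability alpha / (alpha + i).
            expected += alpha / (alpha + i);
        }
        return expected;
    }

    public static void main(String[] args) {
        for (double alpha : new double[] {0.1, 1.0, 10.0}) {
            System.out.printf("alpha=%.1f, n=1000: %.2f expected clusters%n",
                alpha, expectedClusters(alpha, 1000));
        }
    }
}
```

The expectation grows roughly like alpha * log(n), so larger values of alpha favor more clusters. This is a property of the prior only; when reestimateAlpha is enabled, the value of alpha changes during sampling.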