ObservationType - Type of observations handled by the mixture model
@PublicationReference(author="Radford M. Neal", title="Markov Chain Sampling Methods for Dirichlet Process Mixture Models", type=Journal, year=2000, publication="Journal of Computational and Graphical Statistics, Vol. 9, No. 2", pages={249,265}, notes="Based in part on Algorithm 2 from Neal")
@PublicationReference(author={"Michael D. Escobar","Mike West"}, title="Bayesian Density Estimation and Inference Using Mixtures", type=Journal, publication="Journal of the American Statistical Association", year=1995)
public class DirichletProcessMixtureModel<ObservationType>
extends AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
| Modifier and Type | Class and Description |
|---|---|
| static class | DirichletProcessMixtureModel.DPMMCluster<ObservationType>: Cluster for a step in the DPMM |
| protected static class | DirichletProcessMixtureModel.DPMMLogConditional: Container for the log conditional likelihood |
| static class | DirichletProcessMixtureModel.MultivariateMeanCovarianceUpdater: Updater that creates specified clusters with distinct means and covariances |
| static class | DirichletProcessMixtureModel.MultivariateMeanUpdater: Updater that creates specified clusters with identical covariances |
| static class | DirichletProcessMixtureModel.Sample<ObservationType>: A sample from the Dirichlet Process Mixture Model. |
| static interface | DirichletProcessMixtureModel.Updater<ObservationType>: Updater for the DPMM |
| Modifier and Type | Field and Description |
|---|---|
| protected GammaDistribution | alphaInverseSampler: Samples a new alpha-inverse. |
| protected double[] | clusterWeights: Holds the cluster weights so that we don't have to re-allocate them each mcmcUpdate step. |
| protected ProbabilityFunction<ObservationType> | conditionalPriorPredictive: Base predictive distribution that determines the value of the new cluster weighting during the Gibbs sampling. |
| static double | DEFAULT_ALPHA: Default concentration parameter of the Dirichlet Process, 1.0. |
| static int | DEFAULT_NUM_INITIAL_CLUSTERS: Default number of initial clusters |
| static boolean | DEFAULT_REESTIMATE_ALPHA: The default value for re-estimating alpha is true. |
| protected BetaDistribution | etaSampler: Creates a new value of "eta" which, in turn, helps sample a new alpha. |
| protected double | initialAlpha: Initial value of alpha, the concentration parameter of the Dirichlet Process |
| protected boolean | reestimateAlpha: Flag to automatically re-estimate the alpha parameter |
| protected DirichletProcessMixtureModel.Updater<ObservationType> | updater: Creates the clusters and predictive prior distributions |
Inherited fields: currentParameter, DEFAULT_NUM_SAMPLES, previousParameter, random, data, keepGoing, maxIterations, DEFAULT_ITERATION, iteration
| Constructor and Description |
|---|
| DirichletProcessMixtureModel(): Creates a new instance of DirichletProcessMixtureModel |
| Modifier and Type | Method and Description |
|---|---|
| protected java.util.ArrayList<java.util.Collection<ObservationType>> | assignObservationsToClusters(int K, DirichletProcessMixtureModel.DPMMLogConditional logConditional): Assigns observations to each of the K clusters, plus the as-yet-uncreated new cluster |
| protected int | assignObservationToCluster(ObservationType observation, double[] weights, DirichletProcessMixtureModel.DPMMLogConditional logConditional): Probabilistically assigns an observation to a cluster |
| DirichletProcessMixtureModel<ObservationType> | clone(): This makes public the clone method on the Object class and removes the exception that it throws. |
| protected DirichletProcessMixtureModel.DPMMCluster<ObservationType> | createCluster(java.util.Collection<ObservationType> clusterAssignment, DirichletProcessMixtureModel.Updater<ObservationType> localUpdater): Creates a cluster from the given cluster assignment |
| DirichletProcessMixtureModel.Sample<ObservationType> | createInitialLearnedObject(): Creates the initial parameters from which to start the Markov chain. |
| double | getInitialAlpha(): Getter for initialAlpha |
| int | getNumInitialClusters(): Getter for numInitialClusters |
| boolean | getReestimateAlpha(): Getter for reestimateAlpha |
| DirichletProcessMixtureModel.Updater<ObservationType> | getUpdater(): Getter for updater |
| protected void | mcmcUpdate(): Performs a valid MCMC update step. |
| void | setInitialAlpha(double initialAlpha): Setter for initialAlpha |
| void | setNumInitialClusters(int numInitialClusters): Setter for numInitialClusters |
| void | setReestimateAlpha(boolean reestimateAlpha): Setter for reestimateAlpha |
| void | setUpdater(DirichletProcessMixtureModel.Updater<ObservationType> updater): Setter for updater |
| protected double | updateAlpha(double alpha, int numObservations): Runs the Gibbs sampler for the concentration parameter, alpha, given the data. |
| protected java.util.ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> | updateClusters(java.util.ArrayList<java.util.Collection<ObservationType>> clusterAssignments): Update each cluster according to the data assigned to it |
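The assignment step summarized above (assignObservationToCluster, with K existing clusters plus one as-yet-uncreated cluster) can be illustrated with a small standalone sketch. It is not the library's implementation. It assumes the usual weighting from Neal's Algorithm 2, where an existing cluster is weighted by its size times the likelihood of the observation and the new cluster by alpha times the prior predictive. The class ClusterAssignmentSketch and both method names are hypothetical.

```java
import java.util.Random;

/** Hypothetical illustration; not part of DirichletProcessMixtureModel. */
final class ClusterAssignmentSketch {

    /**
     * Builds unnormalized log-weights for K existing clusters plus one new
     * cluster (assumed form: log(n_k) + log-likelihood for cluster k, and
     * log(alpha) + log prior predictive for the new cluster).
     */
    static double[] logWeights(int[] counts, double[] logLikelihoods,
                               double alpha, double logPriorPredictive) {
        int numClusters = counts.length;
        double[] result = new double[numClusters + 1];
        for (int k = 0; k < numClusters; k++) {
            // An empty cluster gets log(0) = -Infinity, i.e. zero probability.
            result[k] = Math.log(counts[k]) + logLikelihoods[k];
        }
        result[numClusters] = Math.log(alpha) + logPriorPredictive;
        return result;
    }

    /** Draws index i with probability proportional to exp(logWeights[i]). */
    static int sampleIndex(double[] logWeights, Random random) {
        // Subtract the maximum before exponentiating for numerical stability.
        double max = Double.NEGATIVE_INFINITY;
        for (double lw : logWeights) {
            max = Math.max(max, lw);
        }
        double[] weights = new double[logWeights.length];
        double total = 0.0;
        for (int i = 0; i < logWeights.length; i++) {
            weights[i] = Math.exp(logWeights[i] - max);
            total += weights[i];
        }
        // Inverse-CDF draw over the unnormalized weights.
        double u = random.nextDouble() * total;
        double cumulative = 0.0;
        for (int i = 0; i < weights.length; i++) {
            cumulative += weights[i];
            if (u < cumulative) {
                return i;
            }
        }
        return weights.length - 1; // Guard against floating-point rounding.
    }

    public static void main(String[] args) {
        double[] lw = logWeights(new int[] {4, 2}, new double[] {-1.5, -0.7},
                                 1.0, -2.0);
        int index = sampleIndex(lw, new Random());
        System.out.println(index == lw.length - 1
            ? "open a new cluster" : "join cluster " + index);
    }
}
```

Returning the last index signals that a new cluster should be created. Reusing a caller-supplied weights array, as the documented weights parameter suggests, would avoid the per-call allocation made here.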
Methods inherited from superclasses:
cleanupAlgorithm, getBurnInIterations, getCurrentParameter, getIterationsPerSample, getPreviousParameter, getRandom, getResult, initializeAlgorithm, setBurnInIterations, setCurrentParameter, setIterationsPerSample, setRandom, setResult, step
getData, getKeepGoing, learn, setData, setKeepGoing, stop
getMaxIterations, isResultValid, setMaxIterations
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interfaces:
learn
getMaxIterations, setMaxIterations
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
isResultValid, stop
public static final double DEFAULT_ALPHA
public static final int DEFAULT_NUM_INITIAL_CLUSTERS
public static final boolean DEFAULT_REESTIMATE_ALPHA
protected DirichletProcessMixtureModel.Updater<ObservationType> updater
protected boolean reestimateAlpha
protected double initialAlpha
protected transient ProbabilityFunction<ObservationType> conditionalPriorPredictive
protected transient double[] clusterWeights
protected transient BetaDistribution etaSampler
protected transient GammaDistribution alphaInverseSampler
public DirichletProcessMixtureModel()
public DirichletProcessMixtureModel<ObservationType> clone()
This makes public the clone method on the Object class and
removes the exception that it throws (description copied from
AbstractCloneableSerializableObject). Its default behavior is to
automatically create a clone of the exact type of object that the
clone is called on and to copy all primitives but to keep all references,
which means it is a shallow copy.
Extensions of this class may want to override this method (but call
super.clone()) to implement a "smart copy", that is, to target
the most common use case for creating a copy of the object. Because
the default behavior is a shallow copy, extending classes only need
to handle fields that need a deeper copy (or those that need to
be reset). Some of the methods in ObjectUtil may be helpful in
implementing a custom clone method.
Note: The contract of this method is that you must use
super.clone() as the basis for your implementation.
clone in interface CloneableSerializable
clone in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
protected void mcmcUpdate()
Performs a valid MCMC update step (description copied from AbstractMarkovChainMonteCarlo).
mcmcUpdate in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
protected java.util.ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> updateClusters(java.util.ArrayList<java.util.Collection<ObservationType>> clusterAssignments)
clusterAssignments - Observations assigned to each cluster
protected java.util.ArrayList<java.util.Collection<ObservationType>> assignObservationsToClusters(int K, DirichletProcessMixtureModel.DPMMLogConditional logConditional)
K - Number of clusters
logConditional - The log of the conditional.
protected int assignObservationToCluster(ObservationType observation, double[] weights, DirichletProcessMixtureModel.DPMMLogConditional logConditional)
observation - Observation that we're assigning
weights - Placeholder for the weights that this method will create
logConditional - The log of the conditional.
protected DirichletProcessMixtureModel.DPMMCluster<ObservationType> createCluster(java.util.Collection<ObservationType> clusterAssignment, DirichletProcessMixtureModel.Updater<ObservationType> localUpdater)
clusterAssignment - Observations assigned to a particular cluster
localUpdater - Updater that recomputes the cluster parameters, needed to ensure thread safety in the parallel implementation
protected double updateAlpha(double alpha, int numObservations)
alpha - Current value of the concentration parameter
numObservations - Number of observations we're sampling over
public DirichletProcessMixtureModel.Sample<ObservationType> createInitialLearnedObject()
Creates the initial parameters from which to start the Markov chain (description copied from AbstractMarkovChainMonteCarlo).
createInitialLearnedObject in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
public DirichletProcessMixtureModel.Updater<ObservationType> getUpdater()
public void setUpdater(DirichletProcessMixtureModel.Updater<ObservationType> updater)
updater - Creates the clusters and predictive prior distributions
public int getNumInitialClusters()
public void setNumInitialClusters(int numInitialClusters)
numInitialClusters - Number of clusters to initialize
public boolean getReestimateAlpha()
public void setReestimateAlpha(boolean reestimateAlpha)
reestimateAlpha - Flag to automatically re-estimate the alpha parameter
public double getInitialAlpha()
public void setInitialAlpha(double initialAlpha)
initialAlpha - Initial value of alpha, the concentration parameter of the Dirichlet Process
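The updateAlpha method, together with the etaSampler and alphaInverseSampler fields, points to an auxiliary-variable Gibbs update for alpha of the kind described in the cited Escobar and West paper. The following is a minimal standalone sketch of such an update, not the library's code. It assumes a Gamma(priorShape, priorRate) prior on alpha, which this page does not specify. The class AlphaGibbsSketch, its method names, and the explicit numClusters argument are hypothetical; the documented method takes only alpha and numObservations.

```java
import java.util.Random;

/** Hypothetical illustration; not part of DirichletProcessMixtureModel. */
final class AlphaGibbsSketch {

    /** Draws from Gamma(shape, rate = 1) using the Marsaglia-Tsang method. */
    static double sampleGamma(double shape, Random random) {
        if (shape < 1.0) {
            // Boosting step: Gamma(a) = Gamma(a + 1) * U^(1/a), U in (0, 1].
            double u = 1.0 - random.nextDouble();
            return sampleGamma(shape + 1.0, random) * Math.pow(u, 1.0 / shape);
        }
        double d = shape - 1.0 / 3.0;
        double c = 1.0 / Math.sqrt(9.0 * d);
        while (true) {
            double x = random.nextGaussian();
            double v = 1.0 + c * x;
            if (v <= 0.0) {
                continue; // Reject: the cubed proposal must be positive.
            }
            v = v * v * v;
            double u = random.nextDouble();
            if (Math.log(u) < 0.5 * x * x + d - d * v + d * Math.log(v)) {
                return d * v;
            }
        }
    }

    /** Draws from Beta(a, b) as a ratio of two Gamma draws. */
    static double sampleBeta(double a, double b, Random random) {
        double x = sampleGamma(a, random);
        double y = sampleGamma(b, random);
        return x / (x + y);
    }

    /**
     * One Gibbs update of alpha given the current number of clusters and
     * observations, under an assumed Gamma(priorShape, priorRate) prior.
     */
    static double updateAlpha(double alpha, int numClusters,
                              int numObservations, double priorShape,
                              double priorRate, Random random) {
        // Step 1: auxiliary variable eta ~ Beta(alpha + 1, n).
        double eta = sampleBeta(alpha + 1.0, numObservations, random);
        // Step 2: alpha is drawn from a two-component Gamma mixture.
        double rate = priorRate - Math.log(eta);
        double odds = (priorShape + numClusters - 1.0)
            / (numObservations * rate);
        double mixtureWeight = odds / (1.0 + odds);
        double shape = (random.nextDouble() < mixtureWeight)
            ? priorShape + numClusters
            : priorShape + numClusters - 1.0;
        return sampleGamma(shape, random) / rate;
    }

    public static void main(String[] args) {
        Random random = new Random();
        double alpha = 1.0; // Matches the documented DEFAULT_ALPHA of 1.0.
        for (int i = 0; i < 5; i++) {
            alpha = updateAlpha(alpha, 4, 200, 1.0, 1.0, random);
            System.out.println("alpha = " + alpha);
        }
    }
}
```

The sketch draws alpha directly, whereas the alphaInverseSampler field suggests the library samples the inverse of alpha; the two parameterizations should not be assumed interchangeable without checking the source.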