Type Parameters:
ResultType - Type of result to expect, such as GradientDescendable
CostFunctionType - Type of cost function to use, such as SumSquaredErrorCostFunction

public abstract class AbstractParameterCostMinimizer<ResultType extends VectorizableVectorFunction,CostFunctionType extends SupervisedCostFunction<Vector,Vector>>
extends AbstractAnytimeSupervisedBatchLearner<Vector,Vector,ResultType>
implements BatchCostMinimizationLearner<java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>>,ResultType>, ParameterCostMinimizer<ResultType>
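The declaration suggests a template-method design: the abstract base class holds the cost function, the iteration cap, the tolerance, and the best result found so far, while concrete subclasses presumably supply the per-iteration update through the inherited `initializeAlgorithm`, `step`, and `cleanupAlgorithm` hooks. The sketch below mirrors that shape in plain, self-contained Java for a single scalar parameter. Every name in it (`ScalarCostMinimizer`, `FiniteDifferenceDescent`, `learn(double)`) is hypothetical and is not the library's API; finite-difference gradient descent is used only as an example subclass.

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical analogue of the abstract base class: it owns the cost
// function, iteration cap, tolerance, and best result, while a subclass
// supplies the single update step. Not the library's real API.
abstract class ScalarCostMinimizer {
    protected final DoubleUnaryOperator costFunction; // cost of a parameter value
    protected final int maxIterations;                // hard cap on steps
    protected final double tolerance;                 // convergence threshold
    protected double result;                          // best parameter found so far
    protected double resultCost;                      // cost of that parameter

    ScalarCostMinimizer(DoubleUnaryOperator costFunction, int maxIterations, double tolerance) {
        this.costFunction = costFunction;
        this.maxIterations = maxIterations;
        this.tolerance = tolerance;
    }

    // Subclass hook: compute the next parameter value from the current one.
    protected abstract double step(double current);

    // Runs from an initial guess until the cost changes by less than
    // tolerance or maxIterations steps have been taken.
    double learn(double initialGuess) {
        double current = initialGuess;
        double currentCost = costFunction.applyAsDouble(current);
        result = current;
        resultCost = currentCost;
        for (int i = 0; i < maxIterations; i++) {
            double next = step(current);
            double nextCost = costFunction.applyAsDouble(next);
            if (nextCost < resultCost) { // keep the best parameter seen
                result = next;
                resultCost = nextCost;
            }
            boolean converged = Math.abs(currentCost - nextCost) < tolerance;
            current = next;
            currentCost = nextCost;
            if (converged) {
                break;
            }
        }
        return result;
    }
}

// Hypothetical subclass: fixed-step gradient descent that estimates the
// derivative with a central difference.
class FiniteDifferenceDescent extends ScalarCostMinimizer {
    private final double stepSize;

    FiniteDifferenceDescent(DoubleUnaryOperator costFunction, int maxIterations,
                            double tolerance, double stepSize) {
        super(costFunction, maxIterations, tolerance);
        this.stepSize = stepSize;
    }

    @Override
    protected double step(double current) {
        double h = 1e-6;
        double slope = (costFunction.applyAsDouble(current + h)
                - costFunction.applyAsDouble(current - h)) / (2 * h);
        return current - stepSize * slope;
    }
}
```

For instance, `new FiniteDifferenceDescent(v -> (v - 3) * (v - 3), 1000, 1.0E-7, 0.1).learn(0.0)` should return a value close to 3. Keeping the best-so-far parameter separate from the current iterate means a poor step never overwrites a good result.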
Modifier and Type | Field and Description |
---|---|
static int | DEFAULT_MAX_ITERATIONS: Default maximum number of iterations before stopping (1000) |
static double | DEFAULT_TOLERANCE: Default convergence criterion (1.0E-7) |
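Taken together, the two defaults describe a two-part stopping rule: give up after `DEFAULT_MAX_ITERATIONS` steps, or stop earlier once the algorithm has converged to within `DEFAULT_TOLERANCE`. The page does not say which quantity the tolerance is compared against, so the hypothetical `StoppingRule` class below assumes it is the absolute change in cost between successive iterations; a relative change or a gradient norm would be equally plausible readings.

```java
// Hypothetical stopping rule combining the two documented defaults.
// Assumption: "convergence" means the absolute change in cost between two
// successive iterations fell below the tolerance.
final class StoppingRule {
    static final int DEFAULT_MAX_ITERATIONS = 1000;  // default from the field table
    static final double DEFAULT_TOLERANCE = 1.0E-7;  // default from the field table

    private final int maxIterations;
    private final double tolerance;

    StoppingRule() {
        this(DEFAULT_MAX_ITERATIONS, DEFAULT_TOLERANCE);
    }

    StoppingRule(int maxIterations, double tolerance) {
        this.maxIterations = maxIterations;
        this.tolerance = tolerance;
    }

    // True once the iteration budget is spent or the cost has stopped changing.
    boolean shouldStop(int iterationsDone, double previousCost, double currentCost) {
        return iterationsDone >= maxIterations
                || Math.abs(previousCost - currentCost) < tolerance;
    }
}
```

With the defaults, `shouldStop(10, 5.0, 4.0)` is false because the cost is still moving by 1.0, while `shouldStop(10, 1.0, 1.0 + 1e-9)` is true because the change is below 1.0E-7.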
Inherited fields:
data, keepGoing
maxIterations
DEFAULT_ITERATION, iteration
Constructor and Description |
---|
AbstractParameterCostMinimizer(CostFunctionType costFunction, int maxIterations, double tolerance): Creates a new instance of AbstractParameterCostMinimizer |
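The constructor takes all three settings explicitly. A common companion idiom is a convenience constructor that delegates to it with the documented defaults. The hypothetical `MinimizerSettings` class below shows that delegation; `DoubleUnaryOperator` stands in for `CostFunctionType`, and the argument checks are illustrative assumptions, since the page does not say whether the real constructor validates anything.

```java
import java.util.Objects;
import java.util.function.DoubleUnaryOperator;

// Hypothetical settings holder mirroring the constructor's three arguments.
// The validation below is an assumption for illustration only.
final class MinimizerSettings {
    static final int DEFAULT_MAX_ITERATIONS = 1000;  // documented default
    static final double DEFAULT_TOLERANCE = 1.0E-7;  // documented default

    final DoubleUnaryOperator costFunction;
    final int maxIterations;
    final double tolerance;

    // Convenience constructor that falls back to the documented defaults.
    MinimizerSettings(DoubleUnaryOperator costFunction) {
        this(costFunction, DEFAULT_MAX_ITERATIONS, DEFAULT_TOLERANCE);
    }

    MinimizerSettings(DoubleUnaryOperator costFunction, int maxIterations, double tolerance) {
        this.costFunction = Objects.requireNonNull(costFunction, "costFunction");
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive");
        }
        if (!(tolerance >= 0.0)) { // also rejects NaN
            throw new IllegalArgumentException("tolerance must be non-negative");
        }
        this.maxIterations = maxIterations;
        this.tolerance = tolerance;
    }
}
```

Delegating with `this(...)` keeps the defaults and any checks in a single place, so the two constructors cannot drift apart.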
Modifier and Type | Method and Description |
---|---|
CostFunctionType | getCostFunction(): Gets the cost function that the learner is minimizing. |
ResultType | getObjectToOptimize(): Getter for objectToOptimize |
NamedValue<java.lang.Double> | getPerformance(): Gets the name-value pair that describes the current performance of the algorithm. |
ResultType | getResult(): Gets the current result of the algorithm. |
protected java.lang.Double | getResultCost(): Getter for resultCost |
double | getTolerance(): Getter for tolerance |
void | setCostFunction(CostFunctionType costFunction): Setter for costFunction |
void | setObjectToOptimize(ResultType objectToOptimize): Setter for objectToOptimize |
protected void | setResult(ResultType result): Setter for result |
protected void | setResultCost(java.lang.Double resultCost): Setter for resultCost |
void | setTolerance(double tolerance): Setter for tolerance |
Inherited methods:
cleanupAlgorithm, clone, getData, getKeepGoing, initializeAlgorithm, learn, setData, setKeepGoing, step, stop
getMaxIterations, isResultValid, setMaxIterations
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners

Methods inherited from class java.lang.Object:
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Inherited interface methods:
learn
clone
getMaxIterations, setMaxIterations
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
isResultValid, stop
public static final double DEFAULT_TOLERANCE
Default convergence criterion (1.0E-7)

public static final int DEFAULT_MAX_ITERATIONS
Default maximum number of iterations before stopping (1000)

public AbstractParameterCostMinimizer(CostFunctionType costFunction, int maxIterations, double tolerance)
Creates a new instance of AbstractParameterCostMinimizer
Parameters:
costFunction - Cost function that computes the cost of the object to optimize
maxIterations - Maximum number of iterations before stopping
tolerance - Stopping criterion for the algorithm, typically ~1e-5

public ResultType getObjectToOptimize()
Getter for objectToOptimize
Specified by:
getObjectToOptimize in interface ParameterCostMinimizer<ResultType extends VectorizableVectorFunction>

public void setObjectToOptimize(ResultType objectToOptimize)
Setter for objectToOptimize
Specified by:
setObjectToOptimize in interface ParameterCostMinimizer<ResultType extends VectorizableVectorFunction>
Parameters:
objectToOptimize - Vectorizable object whose parameters are tuned to minimize the cost function

public ResultType getResult()
Description copied from interface: AnytimeAlgorithm
Gets the current result of the algorithm.
Specified by:
getResult in interface AnytimeAlgorithm<ResultType extends VectorizableVectorFunction>
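`getResult()` comes from `AnytimeAlgorithm`: the algorithm can be asked for its current result at any point, not only after it finishes. To illustrate that contract in isolation, the hypothetical `AnytimeTernarySearch` class below minimizes a unimodal scalar cost by ternary search. That technique is swapped in only because each step visibly refines the answer; it is not how the library's subclasses work.

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical anytime minimizer: ternary search on a unimodal cost over
// [lo, hi]. Each step() shrinks the bracket, and getResult() can be called
// after any number of steps to read the current best estimate.
final class AnytimeTernarySearch {
    private final DoubleUnaryOperator cost;
    private double lo;
    private double hi;

    AnytimeTernarySearch(DoubleUnaryOperator cost, double lo, double hi) {
        this.cost = cost;
        this.lo = lo;
        this.hi = hi;
    }

    // One iteration: discard the outer third of the bracket that cannot
    // contain the minimum of a unimodal function.
    void step() {
        double third = (hi - lo) / 3.0;
        double m1 = lo + third;
        double m2 = hi - third;
        if (cost.applyAsDouble(m1) < cost.applyAsDouble(m2)) {
            hi = m2;
        } else {
            lo = m1;
        }
    }

    // Current result: the midpoint of the remaining bracket.
    double getResult() {
        return (lo + hi) / 2.0;
    }
}
```

Before any step, `getResult()` simply returns the bracket midpoint; each further call to `step()` tightens the estimate, so the caller, not the algorithm, decides when the answer is good enough.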

protected void setResult(ResultType result)
Setter for result
Parameters:
result - Result to return

public double getTolerance()
Getter for tolerance

public void setTolerance(double tolerance)
Setter for tolerance
Parameters:
tolerance - Stopping criterion for the algorithm, typically ~1e-5

public CostFunctionType getCostFunction()
Description copied from interface: BatchCostMinimizationLearner
Gets the cost function that the learner is minimizing.
Specified by:
getCostFunction in interface BatchCostMinimizationLearner<java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>>,ResultType extends VectorizableVectorFunction>
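The type-parameter documentation names `SumSquaredErrorCostFunction` as an example cost function. The sketch below computes a plain sum of squared errors for a hypothetical one-dimensional linear model `y = w * x + b` over hypothetical `Pair` samples. The library's own class may normalize differently (for example by a factor of 1/2 or by the sample count), so treat this as the idea rather than its exact formula.

```java
import java.util.List;

// Hypothetical input-output pair for a scalar model.
final class Pair {
    final double input;
    final double output;

    Pair(double input, double output) {
        this.input = input;
        this.output = output;
    }
}

// Hypothetical sum-of-squared-errors cost for the line y = w * x + b.
// Any extra scaling used by the library's cost function is not modeled here.
final class SumSquaredError {
    static double cost(double w, double b, List<Pair> data) {
        double total = 0.0;
        for (Pair p : data) {
            double error = p.output - (w * p.input + b); // residual for one sample
            total += error * error;
        }
        return total;
    }
}
```

For the samples (1, 2) and (2, 4), the line `w = 2, b = 0` fits exactly and costs 0.0, while `w = 1, b = 0` leaves residuals 1 and 2 and costs 5.0; a minimizer searches for the `(w, b)` that drives this number down.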
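
public void setCostFunction(CostFunctionType costFunction)
Setter for costFunction
Parameters:
costFunction - Cost function that computes the cost of the object to optimize

protected java.lang.Double getResultCost()
Getter for resultCost

protected void setResultCost(java.lang.Double resultCost)
Setter for resultCost
Parameters:
resultCost - Cost of the result

public NamedValue<java.lang.Double> getPerformance()
Description copied from interface: MeasurablePerformanceAlgorithm
Gets the name-value pair that describes the current performance of the algorithm.
Specified by:
getPerformance in interface MeasurablePerformanceAlgorithm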
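`getPerformance()` reports the algorithm's current performance as a name-value pair, and `resultCost` is declared as a boxed `java.lang.Double` rather than a primitive `double`. One plausible reading, not stated on this page, is that the boxed type lets the cost be `null` before any result has been evaluated. The hypothetical `NamedDouble` and `CostTracker` classes below model that reading; the name "cost" is also an assumption.

```java
// Hypothetical stand-in for a name-value pair such as NamedValue<Double>.
final class NamedDouble {
    final String name;
    final Double value; // boxed so that "no value yet" can be null

    NamedDouble(String name, Double value) {
        this.name = name;
        this.value = value;
    }
}

// Hypothetical minimizer state exposing its result cost as a performance
// measure. The name "cost" is an assumption for illustration.
final class CostTracker {
    private Double resultCost; // null until a result has been evaluated

    void setResultCost(Double resultCost) {
        this.resultCost = resultCost;
    }

    NamedDouble getPerformance() {
        return new NamedDouble("cost", resultCost);
    }
}
```

Callers reading such a value should check for `null` before unboxing, since unboxing a `null` `Double` throws a `NullPointerException`.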