public class ParallelizedCostFunctionContainer extends AbstractSupervisedCostFunction<Vector,Vector> implements DifferentiableCostFunction, ParallelAlgorithm
| Modifier and Type | Class and Description |
|---|---|
| `protected static class` | `ParallelizedCostFunctionContainer.SubCostEvaluate`: Callable task for the `evaluate()` method. |
| `protected static class` | `ParallelizedCostFunctionContainer.SubCostGradient`: Callable task for the `computeGradient()` method. |
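The two nested classes are described only as `Callable` tasks for evaluation and gradient computation. The following is a minimal sketch of that fan-out pattern. It assumes each task scores one data partition and returns a partial sum that the caller adds up. `PartitionCostTask`, `ParallelCost`, and the per-example cost function are hypothetical names, not the Foundry's API, and the real tasks may return a different partial-result type.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.ToDoubleFunction;

/** Hypothetical stand-in for a SubCostEvaluate-style task: scores one partition. */
final class PartitionCostTask<T> implements Callable<Double> {
    private final List<T> partition;
    private final ToDoubleFunction<T> pointCost;

    PartitionCostTask(List<T> partition, ToDoubleFunction<T> pointCost) {
        this.partition = partition;
        this.pointCost = pointCost;
    }

    @Override
    public Double call() {
        double sum = 0.0;
        for (T item : partition) {
            sum += pointCost.applyAsDouble(item); // accumulate this partition's cost
        }
        return sum;
    }
}

/** Hypothetical driver: one task per partition, partial sums added at the end. */
final class ParallelCost {
    static <T> double totalCost(ExecutorService pool, List<List<T>> partitions,
                                ToDoubleFunction<T> pointCost) {
        List<Callable<Double>> tasks = new ArrayList<>();
        for (List<T> p : partitions) {
            tasks.add(new PartitionCostTask<>(p, pointCost));
        }
        try {
            double total = 0.0;
            // invokeAll blocks until every task has completed.
            for (Future<Double> f : pool.invokeAll(tasks)) {
                total += f.get();
            }
            return total;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
    }
}
```

Because `invokeAll` waits for all tasks, the partial sums are combined only after every partition has been scored.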
| Constructor and Description |
|---|
| `ParallelizedCostFunctionContainer()`: Default constructor for ParallelizedCostFunctionContainer. |
| `ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction)`: Creates a new instance of ParallelizedCostFunctionContainer. |
| `ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction, java.util.concurrent.ThreadPoolExecutor threadPool)`: Creates a new instance of ParallelizedCostFunctionContainer. |
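The two-argument constructor accepts an explicit `ThreadPoolExecutor`, and `createThreadPool()` is summarized as building the pool from the Foundry's global thread pool. The sketch below shows one plausible lazy-creation pattern under stated assumptions. `PoolHolder` is a hypothetical class. The pool is sized to `availableProcessors()` rather than taken from any Foundry-wide setting. `getNumThreads()` is assumed to report the pool's maximum size.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/** Hypothetical class illustrating lazy thread-pool creation; not the Foundry's implementation. */
final class PoolHolder {
    private ThreadPoolExecutor threadPool;

    /** Creates a fixed-size pool on first use; assumption: sized to the available processors. */
    synchronized void createThreadPool() {
        if (threadPool == null) {
            int n = Runtime.getRuntime().availableProcessors();
            // Same configuration Executors.newFixedThreadPool(n) uses: n core and max threads, unbounded queue.
            threadPool = new ThreadPoolExecutor(n, n, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
        }
    }

    synchronized ThreadPoolExecutor getThreadPool() {
        createThreadPool(); // reentrant lock: safe to call while already holding it
        return threadPool;
    }

    synchronized void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool; // an explicitly supplied pool replaces the lazy default
    }

    /** Assumption: the thread count is the pool's maximum size. */
    int getNumThreads() {
        return getThreadPool().getMaximumPoolSize();
    }
}
```

Creating the pool lazily means a caller who supplies a pool through the constructor or `setThreadPool` never pays for a default one.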
| Modifier and Type | Method and Description |
|---|---|
| `ParallelizedCostFunctionContainer` | `clone()`: This makes public the `clone` method on the `Object` class and removes the exception that it throws. |
| `Vector` | `computeParameterGradient(GradientDescendable function)`: Differentiates `function` with respect to its parameters. |
| `protected void` | `createPartitions()`: Splits the data across the `numComponents` cost functions. |
| `protected void` | `createThreadPool()`: Creates the thread pool using the Foundry's global thread pool. |
| `java.lang.Double` | `evaluate(Evaluator<? super Vector,? extends Vector> evaluator)`: Computes the cost of the given target. |
| `java.lang.Double` | `evaluatePerformance(java.util.Collection<? extends TargetEstimatePair<? extends Vector,? extends Vector>> data)`: Evaluates the performance accuracy of the given estimates against the given targets. |
| `ParallelizableCostFunction` | `getCostFunction()`: Getter for `costFunction`. |
| `int` | `getNumThreads()`: Gets the number of threads in the thread pool. |
| `java.util.concurrent.ThreadPoolExecutor` | `getThreadPool()`: Gets the thread pool for the algorithm to use. |
| `void` | `setCostFunction(ParallelizableCostFunction costFunction)`: Setter for `costFunction`. |
| `void` | `setCostParameters(java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>> costParameters)`: Sets the parameters of the cost function used to evaluate the cost of a target. |
| `void` | `setThreadPool(java.util.concurrent.ThreadPoolExecutor threadPool)`: Sets the thread pool for the algorithm to use. |
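The container splits the data into partitions and evaluates each one separately. The partial results therefore have to be combined so that the total matches a single serial pass over the data. The following sketch assumes a normalized (mean) cost, which the summary does not state. Each partition reports a sum and a count, and the combiner divides the grand total by the total count. Averaging the per-partition means instead gives a different answer whenever partitions have unequal sizes. `PartialCost` and `Amalgamate` are hypothetical names.

```java
import java.util.List;

/** Hypothetical partial result: a sum of per-example costs and how many examples it covers. */
final class PartialCost {
    final double sum;
    final int count;

    PartialCost(double sum, int count) {
        this.sum = sum;
        this.count = count;
    }
}

/** Hypothetical combiner for per-partition results. */
final class Amalgamate {
    /** Partial result for one partition, using squared error as the per-example cost. */
    static PartialCost partial(List<Double> errors) {
        double sum = 0.0;
        for (double e : errors) {
            sum += e * e;
        }
        return new PartialCost(sum, errors.size());
    }

    /** Mean cost over all examples: equal to evaluating every example in one serial pass. */
    static double combine(List<PartialCost> parts) {
        double sum = 0.0;
        int count = 0;
        for (PartialCost p : parts) {
            sum += p.sum;
            count += p.count;
        }
        return count == 0 ? 0.0 : sum / count;
    }

    /** Naive alternative, shown for contrast: biased when partition sizes differ. */
    static double meanOfMeans(List<PartialCost> parts) {
        double total = 0.0;
        for (PartialCost p : parts) {
            total += p.count == 0 ? 0.0 : p.sum / p.count;
        }
        return parts.isEmpty() ? 0.0 : total / parts.size();
    }
}
```

With partitions `[1, 2, 3]` and `[4]`, `combine` returns 7.5, the mean squared value over all four examples. `meanOfMeans` returns about 10.33.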
Methods inherited from the cost-function superclasses and interfaces: `getCostParameters`, `summarize`, `evaluatePerformance`. Methods inherited from `java.lang.Object`: `equals`, `finalize`, `getClass`, `hashCode`, `notify`, `notifyAll`, `toString`, `wait` (three overloads).

`public ParallelizedCostFunctionContainer()`

`public ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction)`
Parameters: `costFunction` - Cost function to parallelize.

`public ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction, java.util.concurrent.ThreadPoolExecutor threadPool)`
Parameters: `costFunction` - Cost function to parallelize; `threadPool` - Thread pool used to parallelize the computation.

`public ParallelizedCostFunctionContainer clone()`
Description copied from class: `AbstractCloneableSerializable`

This makes public the `clone` method on the `Object` class and removes the exception that it throws. Its default behavior is to automatically create a clone of the exact type of object that `clone` is called on and to copy all primitives but keep all references, which means it is a shallow copy.
Extensions of this class may want to override this method (but call `super.clone()`) to implement a "smart copy", that is, to target the most common use case for creating a copy of the object. Because the default behavior is a shallow copy, extending classes only need to handle fields that need a deeper copy (or those that need to be reset). Some of the methods in `ObjectUtil` may be helpful in implementing a custom `clone` method.
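The "smart copy" convention can be illustrated with plain Java. In this sketch, `CloneableBase` is a hypothetical stand-in for `AbstractCloneableSerializable`: it has a public `clone()`, throws no checked exception, and makes a shallow copy. `Holder` is a hypothetical subclass that starts from `super.clone()` and deep-copies only its mutable list field.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical base: public clone(), no checked exception, shallow copy. */
class CloneableBase implements Cloneable {
    @Override
    public CloneableBase clone() {
        try {
            // Object.clone() creates an instance of the exact runtime type;
            // primitives are copied and references are shared.
            return (CloneableBase) super.clone();
        } catch (CloneNotSupportedException e) {
            throw new AssertionError(e); // unreachable: this class implements Cloneable
        }
    }
}

/** Hypothetical subclass that deep-copies only the field that needs it. */
class Holder extends CloneableBase {
    int count;                                // primitive: already copied by super.clone()
    List<Double> weights = new ArrayList<>(); // mutable reference: needs its own copy

    @Override
    public Holder clone() {
        Holder copy = (Holder) super.clone();  // contract: start from super.clone()
        copy.weights = new ArrayList<>(this.weights);
        return copy;
    }
}
```

Only `weights` needs attention; `count` is handled by the shallow copy, which is the point of building on `super.clone()`.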
Note: The contract of this method is that you must use `super.clone()` as the basis for your implementation.

Specified by: `clone` in interface `CostFunction<Evaluator<? super Vector,? extends Vector>,java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>>>`; `clone` in interface `CloneableSerializable`

Declared in superclass: `clone` in class `AbstractSupervisedCostFunction<Vector,Vector>`

`public ParallelizableCostFunction getCostFunction()`

`public void setCostFunction(ParallelizableCostFunction costFunction)`
Parameters: `costFunction` - Cost function to parallelize.

`protected void createPartitions()`
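`createPartitions()` is summarized only as splitting the data across the `numComponents` cost functions; the splitting strategy itself is not documented here. A minimal sketch follows. It assumes contiguous blocks whose sizes differ by at most one. `Partitioner.split` is a hypothetical helper.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper; the actual partitioning strategy is not documented here. */
final class Partitioner {
    /** Splits data into numComponents contiguous blocks whose sizes differ by at most one. */
    static <T> List<List<T>> split(List<T> data, int numComponents) {
        if (numComponents < 1) {
            throw new IllegalArgumentException("numComponents must be at least 1");
        }
        int n = data.size();
        int base = n / numComponents;  // minimum block size
        int extra = n % numComponents; // the first 'extra' blocks get one more element
        List<List<T>> parts = new ArrayList<>(numComponents);
        int start = 0;
        for (int i = 0; i < numComponents; i++) {
            int size = base + (i < extra ? 1 : 0);
            // Copy the subList view so each partition is independent of the source list.
            parts.add(new ArrayList<>(data.subList(start, start + size)));
            start += size;
        }
        return parts;
    }
}
```

Splitting seven items into three partitions yields blocks of sizes 3, 2, and 2. When there are fewer items than partitions, the trailing blocks are empty.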

`public void setCostParameters(java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>> costParameters)`
Description copied from interface: `CostFunction`

Sets the parameters of the cost function used to evaluate the cost of a target.

Specified by: `setCostParameters` in interface `CostFunction<Evaluator<? super Vector,? extends Vector>,java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>>>`

Declared in superclass: `setCostParameters` in class `AbstractSupervisedCostFunction<Vector,Vector>`

Parameters: `costParameters` - The parameters of the cost function.

`public java.lang.Double evaluate(Evaluator<? super Vector,? extends Vector> evaluator)`
Description copied from interface: `CostFunction`

Computes the cost of the given target.

Specified by: `evaluate` in interface `Evaluator<Evaluator<? super Vector,? extends Vector>,java.lang.Double>`; `evaluate` in interface `CostFunction<Evaluator<? super Vector,? extends Vector>,java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>>>`

Declared in superclass: `evaluate` in class `AbstractSupervisedCostFunction<Vector,Vector>`

Parameters: `evaluator` - The object to evaluate.

`public java.lang.Double evaluatePerformance(java.util.Collection<? extends TargetEstimatePair<? extends Vector,? extends Vector>> data)`
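Description copied from interface: `SupervisedPerformanceEvaluator`

Evaluates the performance accuracy of the given estimates against the given targets.

Specified by: `evaluatePerformance` in interface `SupervisedPerformanceEvaluator<Vector,Vector,Vector,java.lang.Double>`

Declared in superclass: `evaluatePerformance` in class `AbstractSupervisedCostFunction<Vector,Vector>`

Parameters: `data` - The target-estimate pairs to use to evaluate performance.

`public Vector computeParameterGradient(GradientDescendable function)`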
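Description copied from interface: `DifferentiableCostFunction`

Differentiates `function` with respect to its parameters.

Specified by: `computeParameterGradient` in interface `DifferentiableCostFunction`

Parameters: `function` - The object to differentiate.

`public java.util.concurrent.ThreadPoolExecutor getThreadPool()`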
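Description copied from interface: `ParallelAlgorithm`

Gets the thread pool for the algorithm to use.

Specified by: `getThreadPool` in interface `ParallelAlgorithm`

`public void setThreadPool(java.util.concurrent.ThreadPoolExecutor threadPool)`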
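Description copied from interface: `ParallelAlgorithm`

Sets the thread pool for the algorithm to use.

Specified by: `setThreadPool` in interface `ParallelAlgorithm`

Parameters: `threadPool` - Thread pool used for parallelization.

`public int getNumThreads()`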
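Description copied from interface: `ParallelAlgorithm`

Gets the number of threads in the thread pool.

Specified by: `getNumThreads` in interface `ParallelAlgorithm`

`protected void createThreadPool()`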