public class ParallelizedCostFunctionContainer extends AbstractSupervisedCostFunction<Vector,Vector> implements DifferentiableCostFunction, ParallelAlgorithm
Modifier and Type | Class | Description |
---|---|---|
protected static class | ParallelizedCostFunctionContainer.SubCostEvaluate | Callable task for the evaluate() method. |
protected static class | ParallelizedCostFunctionContainer.SubCostGradient | Callable task for the computeParameterGradient() method. |
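The two nested Callable tasks suggest a partition, evaluate-in-parallel, then combine pattern. The following is a minimal standalone sketch of that pattern. It assumes a mean-squared-error cost over scalar {target, estimate} pairs and a {sum, count} partial result; ParallelCostSketch, meanSquaredError, and this SubCostEvaluate are hypothetical and do not reproduce the Foundry's ParallelizableCostFunction API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ParallelCostSketch {

    /** Hypothetical analogue of SubCostEvaluate: computes one partition's partial cost. */
    static final class SubCostEvaluate implements Callable<double[]> {
        private final List<double[]> partition; // each element is {target, estimate}

        SubCostEvaluate(List<double[]> partition) {
            this.partition = partition;
        }

        @Override
        public double[] call() {
            // Partial result {sum of squared errors, number of pairs}, combined by the caller.
            double sum = 0.0;
            for (double[] pair : partition) {
                double error = pair[0] - pair[1];
                sum += error * error;
            }
            return new double[] {sum, partition.size()};
        }
    }

    /** Evaluates each partition on the pool, then amalgamates the partial sums into a mean. */
    public static double meanSquaredError(List<List<double[]>> partitions, ThreadPoolExecutor pool)
            throws InterruptedException, ExecutionException {
        List<Callable<double[]>> tasks = new ArrayList<>();
        for (List<double[]> partition : partitions) {
            tasks.add(new SubCostEvaluate(partition));
        }
        double totalSum = 0.0;
        double totalCount = 0.0;
        for (Future<double[]> future : pool.invokeAll(tasks)) {
            double[] partial = future.get(); // rethrows any task failure
            totalSum += partial[0];
            totalCount += partial[1];
        }
        return totalCount > 0 ? totalSum / totalCount : 0.0;
    }

    public static void main(String[] args) throws Exception {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                2, 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
        try {
            List<List<double[]>> partitions = List.of(
                    List.of(new double[] {1.0, 0.0}, new double[] {2.0, 2.0}),
                    List.of(new double[] {3.0, 1.0}, new double[] {0.0, 1.0}));
            System.out.println(meanSquaredError(partitions, pool)); // prints 1.5
        } finally {
            pool.shutdown();
        }
    }
}
```

Returning a sum and a count rather than a per-partition mean keeps the amalgamation step exact even when partitions have different sizes.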
Constructor | Description |
---|---|
ParallelizedCostFunctionContainer() | Default constructor for ParallelizedCostFunctionContainer. |
ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction) | Creates a new instance of ParallelizedCostFunctionContainer. |
ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction, java.util.concurrent.ThreadPoolExecutor threadPool) | Creates a new instance of ParallelizedCostFunctionContainer. |
Modifier and Type | Method | Description |
---|---|---|
ParallelizedCostFunctionContainer | clone() | This makes public the clone method on the Object class and removes the exception that it throws. |
Vector | computeParameterGradient(GradientDescendable function) | Differentiates function with respect to its parameters. |
protected void | createPartitions() | Splits the data across the numComponents cost functions. |
protected void | createThreadPool() | Creates the thread pool using the Foundry's global thread pool. |
java.lang.Double | evaluate(Evaluator<? super Vector,? extends Vector> evaluator) | Computes the cost of the given target. |
java.lang.Double | evaluatePerformance(java.util.Collection<? extends TargetEstimatePair<? extends Vector,? extends Vector>> data) | Evaluates the performance accuracy of the given estimates against the given targets. |
ParallelizableCostFunction | getCostFunction() | Gets the cost function being parallelized. |
int | getNumThreads() | Gets the number of threads in the thread pool. |
java.util.concurrent.ThreadPoolExecutor | getThreadPool() | Gets the thread pool for the algorithm to use. |
void | setCostFunction(ParallelizableCostFunction costFunction) | Sets the cost function to parallelize. |
void | setCostParameters(java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>> costParameters) | Sets the parameters of the cost function used to evaluate the cost of a target. |
void | setThreadPool(java.util.concurrent.ThreadPoolExecutor threadPool) | Sets the thread pool for the algorithm to use. |
Inherited methods: getCostParameters, summarize, evaluatePerformance
Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

public ParallelizedCostFunctionContainer()
Default constructor for ParallelizedCostFunctionContainer.

public ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction)
Creates a new instance of ParallelizedCostFunctionContainer.
Parameters:
costFunction - Cost function to parallelize.

public ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction, java.util.concurrent.ThreadPoolExecutor threadPool)
Creates a new instance of ParallelizedCostFunctionContainer.
Parameters:
costFunction - Cost function to parallelize.
threadPool - Thread pool used to parallelize the computation.

public ParallelizedCostFunctionContainer clone()
Description copied from class: AbstractCloneableSerializable
This makes public the clone method on the Object class and removes the exception that it throws. Its default behavior is to automatically create a clone of the exact type of object that the clone is called on and to copy all primitives but to keep all references, which means it is a shallow copy. Extensions of this class may want to override this method (but call super.clone()) to implement a "smart copy", that is, to target the most common use case for creating a copy of the object. Because the default behavior is a shallow copy, extending classes only need to handle fields that need a deeper copy (or those that need to be reset). Some of the methods in ObjectUtil may be helpful in implementing a custom clone method.
Note: The contract of this method is that you must use super.clone() as the basis for your implementation.
Specified by:
clone in interface CostFunction<Evaluator<? super Vector,? extends Vector>,java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>>>
Specified by:
clone in interface CloneableSerializable
Overrides:
clone in class AbstractSupervisedCostFunction<Vector,Vector>
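The "smart copy" advice can be illustrated with plain Java. This is a minimal sketch, assuming a hypothetical SmartCopyModel class with one primitive field, one mutable array that needs a deeper copy, and one derived field that should be reset; it does not use the Foundry's AbstractCloneableSerializable or ObjectUtil.

```java
public class SmartCopyModel implements Cloneable {
    private int numIterations;  // primitive: copied as-is by super.clone()
    private double[] weights;   // mutable array: shared by a shallow copy, so copy it
    private Object cache;       // hypothetical derived state: reset in the copy

    public SmartCopyModel(int numIterations, double[] weights) {
        this.numIterations = numIterations;
        this.weights = weights;
        this.cache = new Object();
    }

    /** Public, covariant, and free of the checked CloneNotSupportedException. */
    @Override
    public SmartCopyModel clone() {
        try {
            // Use super.clone() as the basis: a shallow copy of the exact runtime type.
            SmartCopyModel copy = (SmartCopyModel) super.clone();
            // Deepen only the fields that need it.
            copy.weights = (this.weights == null) ? null : this.weights.clone();
            copy.cache = null;
            return copy;
        } catch (CloneNotSupportedException e) {
            // Cannot happen because this class implements Cloneable.
            throw new AssertionError(e);
        }
    }

    public int getNumIterations() { return numIterations; }
    public double[] getWeights() { return weights; }
    public Object getCache() { return cache; }

    public static void main(String[] args) {
        SmartCopyModel original = new SmartCopyModel(5, new double[] {1.0, 2.0});
        SmartCopyModel copy = original.clone();
        copy.getWeights()[0] = 42.0;
        System.out.println(original.getWeights()[0]); // prints 1.0
    }
}
```

Because super.clone() already copies every field, the override touches only the array and the cached value, which is the division of labor the description above recommends.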
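
public ParallelizableCostFunction getCostFunction()
Gets the cost function being parallelized.

public void setCostFunction(ParallelizableCostFunction costFunction)
Sets the cost function to parallelize.
Parameters:
costFunction - Cost function to parallelize.

protected void createPartitions()
Splits the data across the numComponents cost functions.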
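One plausible way to split the cost parameters across numComponents sub-cost functions is contiguous chunking with sizes that differ by at most one. This sketch assumes that strategy; PartitionSketch and its static createPartitions helper are hypothetical, and the Foundry's actual partitioning scheme is not documented on this page.

```java
import java.util.ArrayList;
import java.util.List;

public final class PartitionSketch {

    /** Splits data into numComponents contiguous partitions whose sizes differ by at most one. */
    public static <T> List<List<T>> createPartitions(List<T> data, int numComponents) {
        if (numComponents <= 0) {
            throw new IllegalArgumentException("numComponents must be positive");
        }
        List<List<T>> partitions = new ArrayList<>(numComponents);
        int n = data.size();
        for (int i = 0; i < numComponents; i++) {
            // Boundaries floor(i * n / k); long arithmetic avoids int overflow.
            int from = (int) ((long) i * n / numComponents);
            int to = (int) ((long) (i + 1) * n / numComponents);
            // Copy so each partition is independent of the source list.
            partitions.add(new ArrayList<>(data.subList(from, to)));
        }
        return partitions;
    }

    public static void main(String[] args) {
        System.out.println(createPartitions(List.of(1, 2, 3, 4, 5, 6, 7), 3));
        // prints [[1, 2], [3, 4], [5, 6, 7]]
    }
}
```

Each partition can then be handed to its own sub-cost function, so the per-thread workload stays balanced.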

public void setCostParameters(java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>> costParameters)
Description copied from interface: CostFunction
Sets the parameters of the cost function used to evaluate the cost of a target.
Specified by:
setCostParameters in interface CostFunction<Evaluator<? super Vector,? extends Vector>,java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>>>
Overrides:
setCostParameters in class AbstractSupervisedCostFunction<Vector,Vector>
Parameters:
costParameters - The parameters of the cost function.

public java.lang.Double evaluate(Evaluator<? super Vector,? extends Vector> evaluator)
Description copied from interface: CostFunction
Computes the cost of the given target.
Specified by:
evaluate in interface Evaluator<Evaluator<? super Vector,? extends Vector>,java.lang.Double>
Specified by:
evaluate in interface CostFunction<Evaluator<? super Vector,? extends Vector>,java.util.Collection<? extends InputOutputPair<? extends Vector,Vector>>>
Overrides:
evaluate in class AbstractSupervisedCostFunction<Vector,Vector>
Parameters:
evaluator - The object to evaluate.

public java.lang.Double evaluatePerformance(java.util.Collection<? extends TargetEstimatePair<? extends Vector,? extends Vector>> data)
Description copied from interface: SupervisedPerformanceEvaluator
Evaluates the performance accuracy of the given estimates against the given targets.
Specified by:
evaluatePerformance in interface SupervisedPerformanceEvaluator<Vector,Vector,Vector,java.lang.Double>
Overrides:
evaluatePerformance in class AbstractSupervisedCostFunction<Vector,Vector>
Parameters:
data - The target-estimate pairs to use to evaluate performance.

public Vector computeParameterGradient(GradientDescendable function)
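Description copied from interface: DifferentiableCostFunction
Differentiates function with respect to its parameters.
Specified by:
computeParameterGradient in interface DifferentiableCostFunction
Parameters:
function - The object to differentiate.

public java.util.concurrent.ThreadPoolExecutor getThreadPool()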
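Description copied from interface: ParallelAlgorithm
Gets the thread pool for the algorithm to use.
Specified by:
getThreadPool in interface ParallelAlgorithm

public void setThreadPool(java.util.concurrent.ThreadPoolExecutor threadPool)
Description copied from interface: ParallelAlgorithm
Sets the thread pool for the algorithm to use.
Specified by:
setThreadPool in interface ParallelAlgorithm
Parameters:
threadPool - Thread pool used for parallelization.

public int getNumThreads()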
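Description copied from interface: ParallelAlgorithm
Gets the number of threads in the thread pool.
Specified by:
getNumThreads in interface ParallelAlgorithm

protected void createThreadPool()
Creates the thread pool using the Foundry's global thread pool.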
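The ParallelAlgorithm accessors pair naturally with lazy pool creation. This is a standalone sketch under two assumptions: the pool is created on first access, and "number of threads" means the pool's maximum size. ThreadPoolHolder is hypothetical, it is not thread-safe during lazy initialization, and a fixed pool sized to availableProcessors() stands in for the Foundry's global thread pool.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ThreadPoolHolder {
    private ThreadPoolExecutor threadPool;

    /** Hypothetical stand-in for createThreadPool(): one worker per available processor. */
    protected void createThreadPool() {
        int n = Runtime.getRuntime().availableProcessors();
        this.threadPool = new ThreadPoolExecutor(
                n, n, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
    }

    /** Creates the pool lazily on first access (single-threaded callers assumed). */
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.createThreadPool();
        }
        return this.threadPool;
    }

    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    /** Reports the pool's maximum size as its thread count (an assumption, see lead-in). */
    public int getNumThreads() {
        return this.getThreadPool().getMaximumPoolSize();
    }

    public static void main(String[] args) {
        ThreadPoolHolder holder = new ThreadPoolHolder();
        System.out.println("threads: " + holder.getNumThreads());
        holder.getThreadPool().shutdown();
    }
}
```

Supplying a pool through setThreadPool before first use skips the default creation entirely, which lets several algorithms share one executor.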