CategoryType - The type of output categories. Can be any type that has valid
equals and hashCode methods.
public class OnlineMultiPerceptron<CategoryType> extends AbstractBatchAndIncrementalLearner<InputOutputPair<? extends Vectorizable,CategoryType>,LinearMultiCategorizer<CategoryType>> implements VectorFactoryContainer
Modifier and Type | Class and Description |
---|---|
static class |
OnlineMultiPerceptron.ProportionalUpdate<CategoryType>
Variant of a multi-category Perceptron that performs a proportional
weight update on all categories that are scored higher than the true
category such that the weights sum to 1.0 and are proportional to how much
larger the score was for each incorrect category than the true category.
|
static class |
OnlineMultiPerceptron.UniformUpdate<CategoryType>
Variant of a multi-category Perceptron that performs a uniform weight
update on all categories that are scored higher than the true category
such that the weights are equal and sum to -1.
|
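The two weighting schemes described above can be sketched as follows. This is a minimal illustration using plain arrays, not the library's implementation: the class name `UpdateWeightSketch` and both method names are hypothetical, categories are represented by array indices, and the margin is ignored. The sketch returns weight magnitudes only; the sign with which they are applied to the weight vectors is left out.

```java
import java.util.Arrays;

/** Hypothetical helper for illustration; not part of the documented API. */
final class UpdateWeightSketch {

    /** Equal weights over every category scored above the true category. */
    static double[] uniform(double[] scores, int trueIndex) {
        double[] weights = new double[scores.length];
        int violators = 0;
        for (double score : scores) {
            if (score > scores[trueIndex]) {
                violators++;
            }
        }
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > scores[trueIndex]) {
                weights[i] = 1.0 / violators;
            }
        }
        return weights; // all zeros when no category beats the true one
    }

    /** Weights proportional to how far each category's score exceeds the true score. */
    static double[] proportional(double[] scores, int trueIndex) {
        double[] weights = new double[scores.length];
        double total = 0.0;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > scores[trueIndex]) {
                weights[i] = scores[i] - scores[trueIndex];
                total += weights[i];
            }
        }
        if (total > 0.0) {
            for (int i = 0; i < weights.length; i++) {
                weights[i] /= total; // normalize so the weights sum to 1.0
            }
        }
        return weights;
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 3.0, 2.0, 0.5}; // index 0 is the true category
        System.out.println(Arrays.toString(uniform(scores, 0)));
        System.out.println(Arrays.toString(proportional(scores, 0)));
    }
}
```

With the scores in `main`, categories 1 and 2 outscore the true category by 2.0 and 1.0, so the uniform scheme splits the weight evenly between them while the proportional scheme gives category 1 twice the weight of category 2.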
Modifier and Type | Field and Description |
---|---|
static double |
DEFAULT_MIN_MARGIN
The default minimum margin is 0.0.
|
protected double |
minMargin
The minimum margin to enforce.
|
protected VectorFactory<?> |
vectorFactory
The factory to create weight vectors.
|
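To illustrate what enforcing a minimum margin means, here is a small sketch under stated assumptions. The class `MarginCheckSketch` and its method are hypothetical, categories are array indices, and the choice of a strict comparison (so that a tie at margin 0.0 counts as incorrect) is an assumption rather than confirmed behavior of the documented class.

```java
/** Hypothetical helper for illustration; not part of the documented API. */
final class MarginCheckSketch {

    /**
     * Returns true when the true category's score exceeds every other
     * category's score by more than minMargin (assumed strict comparison).
     */
    static boolean isCorrect(double[] scores, int trueIndex, double minMargin) {
        if (minMargin < 0.0) {
            throw new IllegalArgumentException("minMargin cannot be negative");
        }
        double bestOther = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < scores.length; i++) {
            if (i != trueIndex) {
                bestOther = Math.max(bestOther, scores[i]);
            }
        }
        return scores[trueIndex] - bestOther > minMargin;
    }

    public static void main(String[] args) {
        double[] scores = {2.0, 1.5};
        System.out.println(isCorrect(scores, 0, 0.0)); // margin of 0.5 clears 0.0
        System.out.println(isCorrect(scores, 0, 1.0)); // margin of 0.5 does not clear 1.0
    }
}
```

A larger minimum margin makes the learner keep updating on examples it already classifies correctly but only narrowly.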
Constructor and Description |
---|
OnlineMultiPerceptron()
Creates a new OnlineMultiPerceptron. |
OnlineMultiPerceptron(double minMargin)
Creates a new OnlineMultiPerceptron with the given minimum margin. |
OnlineMultiPerceptron(double minMargin, VectorFactory<?> vectorFactory)
Creates a new OnlineMultiPerceptron with the given minimum margin and backing vector factory. |
Modifier and Type | Method and Description |
---|---|
LinearMultiCategorizer<CategoryType> |
createInitialLearnedObject()
Creates a new initial learned object, before any data is given.
|
double |
getMinMargin()
Gets the minimum margin to enforce.
|
VectorFactory<?> |
getVectorFactory()
Gets the VectorFactory used to create the weight vector.
|
void |
setMinMargin(double minMargin)
Sets the minimum margin to enforce.
|
void |
setVectorFactory(VectorFactory<?> vectorFactory)
Sets the VectorFactory used to create the weight vector.
|
void |
update(LinearMultiCategorizer<CategoryType> target,
InputOutputPair<? extends Vectorizable,CategoryType> example)
The update method updates an object of ResultType using
the given new data of type DataType, using some form of
"learning" algorithm. |
Inherited methods: clone, learn, learn, update
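As a rough picture of what an incremental update on a single example can look like, the sketch below applies the classical multi-category perceptron rule with a minimum margin. It is an assumption-laden illustration, not the documented class: `MultiPerceptronSketch` is a hypothetical name, inputs are plain `double[]` arrays instead of `Vectorizable`, categories are integer indices instead of `CategoryType`, and there is no bias term. The documented implementation may differ in these details.

```java
/** Hypothetical stand-in for illustration; not part of the documented API. */
final class MultiPerceptronSketch {
    final double[][] weights; // one weight row per category
    final double minMargin;

    MultiPerceptronSketch(int categories, int dimensions, double minMargin) {
        this.weights = new double[categories][dimensions]; // starts at zero
        this.minMargin = minMargin;
    }

    double score(int category, double[] input) {
        double sum = 0.0;
        for (int d = 0; d < input.length; d++) {
            sum += weights[category][d] * input[d];
        }
        return sum;
    }

    /** Returns the highest-scoring category; ties go to the lowest index. */
    int predict(double[] input) {
        int best = 0;
        for (int c = 1; c < weights.length; c++) {
            if (score(c, input) > score(best, input)) {
                best = c;
            }
        }
        return best;
    }

    /** Updates on one example; returns true if the weights changed. */
    boolean update(double[] input, int trueCategory) {
        int rival = -1;
        double rivalScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < weights.length; c++) {
            double s = score(c, input);
            if (c != trueCategory && s > rivalScore) {
                rival = c;
                rivalScore = s;
            }
        }
        // No rival, or the margin is already satisfied: nothing to do.
        if (rival < 0 || score(trueCategory, input) - rivalScore > minMargin) {
            return false;
        }
        for (int d = 0; d < input.length; d++) {
            weights[trueCategory][d] += input[d]; // pull the true category toward the input
            weights[rival][d] -= input[d];        // push the strongest rival away
        }
        return true;
    }

    public static void main(String[] args) {
        MultiPerceptronSketch learner = new MultiPerceptronSketch(2, 2, 0.0);
        double[] x = {1.0, 0.0};
        System.out.println(learner.update(x, 0)); // first pass changes the weights
        System.out.println(learner.update(x, 0)); // second pass finds the margin satisfied
    }
}
```

Calling `update` once per example as data arrives mirrors the incremental style described above: the learned object starts empty and is adjusted only when an example is misclassified or falls inside the margin.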
public static final double DEFAULT_MIN_MARGIN
protected double minMargin
protected VectorFactory<?> vectorFactory
public OnlineMultiPerceptron()
Creates a new OnlineMultiPerceptron.
public OnlineMultiPerceptron(double minMargin)
Creates a new OnlineMultiPerceptron with the given minimum margin.
minMargin - The minimum margin to consider an example correct.
public OnlineMultiPerceptron(double minMargin, VectorFactory<?> vectorFactory)
Creates a new OnlineMultiPerceptron with the given minimum margin and backing vector factory.
minMargin - The minimum margin to consider an example correct.
vectorFactory - The vector factory used to create the weight vectors.
public LinearMultiCategorizer<CategoryType> createInitialLearnedObject()
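Description copied from interface: IncrementalLearner
Creates a new initial learned object, before any data is given.
Specified by: createInitialLearnedObject in interface IncrementalLearner<InputOutputPair<? extends Vectorizable,CategoryType>,LinearMultiCategorizer<CategoryType>>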
public void update(LinearMultiCategorizer<CategoryType> target, InputOutputPair<? extends Vectorizable,CategoryType> example)
Description copied from interface: IncrementalLearner
The update method updates an object of ResultType using the given new data of type DataType, using some form of "learning" algorithm.
Specified by: update in interface IncrementalLearner<InputOutputPair<? extends Vectorizable,CategoryType>,LinearMultiCategorizer<CategoryType>>
target - The object to update.
example - The new data for the learning algorithm to use to update the object.
public double getMinMargin()
Gets the minimum margin to enforce.
public void setMinMargin(double minMargin)
Sets the minimum margin to enforce.
minMargin - The minimum margin. Cannot be negative.
public VectorFactory<?> getVectorFactory()
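Gets the VectorFactory used to create the weight vector.
Specified by: getVectorFactory in interface VectorFactoryContainer
public void setVectorFactory(VectorFactory<?> vectorFactory)
Sets the VectorFactory used to create the weight vector.
vectorFactory - The VectorFactory used to create the weight vector.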