@PublicationReference(
    title="Factorization Machines",
    author="Steffen Rendle",
    year=2010,
    type=Conference,
    publication="Proceedings of the 10th IEEE International Conference on Data Mining (ICDM)",
    url="http://www.inf.uni-konstanz.de/~rendle/pdf/Rendle2010FM.pdf")
@PublicationReference(
    title="Factorization Machines with libFM",
    author="Steffen Rendle",
    year=2012,
    type=Journal,
    publication="ACM Transactions on Intelligent Systems Technology",
    url="http://www.csie.ntu.edu.tw/~b97053/paper/Factorization%20Machines%20with%20libFM.pdf",
    notes="Algorithm 1: Stochastic Gradient Descent (SGD)")
public class FactorizationMachineStochasticGradient
extends AbstractFactorizationMachineLearner
implements MeasurablePerformanceAlgorithm
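For orientation, here is a minimal usage sketch, not taken from the library's own documentation. Only the constructor parameter list and the learn/getResult methods are taken from the signatures documented on this page; the package names (gov.sandia.cognition.*), the helper classes VectorFactory and DefaultInputOutputPair, and the example class name are assumptions about the surrounding Cognitive Foundry API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import gov.sandia.cognition.learning.algorithm.factor.machine.FactorizationMachine;
import gov.sandia.cognition.learning.algorithm.factor.machine.FactorizationMachineStochasticGradient;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;

public class FactorizationMachineSgdExample
{
    public static void main(final String[] args)
    {
        // Toy training data: each example pairs a feature vector with a target value.
        final List<InputOutputPair<Vector, Double>> data = new ArrayList<>();
        data.add(DefaultInputOutputPair.create(
            VectorFactory.getDefault().copyValues(1.0, 0.0, 1.0), 5.0));
        data.add(DefaultInputOutputPair.create(
            VectorFactory.getDefault().copyValues(0.0, 1.0, 1.0), 3.0));

        // Constructor arguments follow the parameter list documented below:
        // factorCount, learningRate, biasRegularization, weightRegularization,
        // factorRegularization, seedScale, maxIterations, random.
        final FactorizationMachineStochasticGradient learner =
            new FactorizationMachineStochasticGradient(
                2,      // factorCount
                0.001,  // learningRate (same value as DEFAULT_LEARNING_RATE)
                0.0,    // biasRegularization
                0.01,   // weightRegularization
                0.1,    // factorRegularization
                0.01,   // seedScale
                100,    // maxIterations
                new Random(42));

        // Run stochastic gradient descent over the data.
        final FactorizationMachine result = learner.learn(data);
        // 'result' now holds the learned model and can be applied to new feature vectors.
    }
}
```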
See Also:
FactorizationMachine, Serialized Form

Field Summary

Modifier and Type | Field and Description |
---|---|
protected java.util.ArrayList<? extends InputOutputPair<? extends Vector,java.lang.Double>> | dataList: The input data represented as a list for fast access. |
static double | DEFAULT_LEARNING_RATE: The default learning rate is 0.001. |
protected double | learningRate: The learning rate for the algorithm. |
protected double | totalChange: The total change in the factorization machine parameters for the current iteration. |
protected double | totalError: The total error for the current iteration. |
Fields inherited from class AbstractFactorizationMachineLearner:
biasEnabled, biasRegularization, DEFAULT_BIAS_ENABLED, DEFAULT_BIAS_REGULARIZATION, DEFAULT_FACTOR_COUNT, DEFAULT_FACTOR_REGULARIZATION, DEFAULT_MAX_ITERATIONS, DEFAULT_SEED_SCALE, DEFAULT_WEIGHT_REGULARIZATION, DEFAULT_WEIGHTS_ENABLED, dimensionality, factorCount, factorRegularization, random, result, seedScale, weightRegularization, weightsEnabled

Fields inherited from class AbstractAnytimeBatchLearner:
data, keepGoing

Fields inherited from class AbstractAnytimeAlgorithm:
maxIterations

Fields inherited from class AbstractIterativeAlgorithm:
DEFAULT_ITERATION, iteration
Constructor Summary

Constructor and Description |
---|
FactorizationMachineStochasticGradient(): Creates a new FactorizationMachineStochasticGradient with default parameters. |
FactorizationMachineStochasticGradient(int factorCount, double learningRate, double biasRegularization, double weightRegularization, double factorRegularization, double seedScale, int maxIterations, java.util.Random random): Creates a new FactorizationMachineStochasticGradient with the given parameters. |
Method Summary

Modifier and Type | Method and Description |
---|---|
protected void | cleanupAlgorithm(): Called to clean up the learning algorithm's state after learning has finished. |
double | getLearningRate(): Gets the learning rate. |
double | getObjective(): Gets the total objective, which is the mean squared error plus the regularization terms. |
NamedValue<? extends java.lang.Number> | getPerformance(): Gets the name-value pair that describes the current performance of the algorithm. |
double | getRegularizationPenalty(): Gets the regularization penalty term for the current result. |
double | getTotalChange(): Gets the total change from the current iteration. |
double | getTotalError(): Gets the total squared error from the current iteration. |
protected boolean | initializeAlgorithm(): Called to initialize the learning algorithm's state based on the data that is stored in the data field. |
void | setLearningRate(double learningRate): Sets the learning rate. |
protected boolean | step(): Called to take a single step of the learning algorithm. |
protected void | update(InputOutputPair<? extends Vector,java.lang.Double> example): Performs a single stochastic gradient descent update of the parameters according to the given example. |
Methods inherited from class AbstractFactorizationMachineLearner:
getBiasRegularization, getFactorCount, getFactorRegularization, getRandom, getResult, getSeedScale, getWeightRegularization, isBiasEnabled, isFactorsEnabled, isWeightsEnabled, setBiasEnabled, setBiasRegularization, setFactorCount, setFactorRegularization, setRandom, setSeedScale, setWeightRegularization, setWeightsEnabled

Methods inherited from class AbstractAnytimeBatchLearner:
clone, getData, getKeepGoing, learn, setData, setKeepGoing, stop

Methods inherited from class AbstractAnytimeAlgorithm:
getMaxIterations, isResultValid, setMaxIterations

Methods inherited from class AbstractIterativeAlgorithm:
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners

Methods inherited from class java.lang.Object:
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface BatchLearner:
learn

Methods inherited from interface CloneableSerializable:
clone

Methods inherited from interface AnytimeAlgorithm:
getMaxIterations, setMaxIterations

Methods inherited from interface IterativeAlgorithm:
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener

Methods inherited from interface StoppableAlgorithm:
isResultValid
Field Detail

public static final double DEFAULT_LEARNING_RATE
The default learning rate is 0.001.

protected double learningRate
The learning rate for the algorithm.

protected transient java.util.ArrayList<? extends InputOutputPair<? extends Vector,java.lang.Double>> dataList
The input data represented as a list for fast access.

protected transient double totalError
The total error for the current iteration.

protected transient double totalChange
The total change in the factorization machine parameters for the current iteration.
Constructor Detail

public FactorizationMachineStochasticGradient()
Creates a new FactorizationMachineStochasticGradient with default parameters.

public FactorizationMachineStochasticGradient(int factorCount,
                                              double learningRate,
                                              double biasRegularization,
                                              double weightRegularization,
                                              double factorRegularization,
                                              double seedScale,
                                              int maxIterations,
                                              java.util.Random random)
Creates a new FactorizationMachineStochasticGradient with the given parameters.
Parameters:
factorCount - The number of factors to use. Zero means no factors. Cannot be negative.
learningRate - The learning rate. Must be positive.
biasRegularization - The regularization term for the bias. Cannot be negative.
weightRegularization - The regularization term for the linear weights. Cannot be negative.
factorRegularization - The regularization term for the factor matrix. Cannot be negative.
seedScale - The random initialization scale for the factors. Multiplied by a random Gaussian to initialize each factor value. Cannot be negative.
maxIterations - The maximum number of iterations for the algorithm to run. Cannot be negative.
random - The random number generator.

Method Detail

protected boolean initializeAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to initialize the learning algorithm's state based on the data that is stored in the data field.
Overrides:
initializeAlgorithm in class AbstractFactorizationMachineLearner
protected boolean step()
Description copied from class: AbstractAnytimeBatchLearner
Called to take a single step of the learning algorithm.
Specified by:
step in class AbstractAnytimeBatchLearner<java.util.Collection<? extends InputOutputPair<? extends Vector,java.lang.Double>>,FactorizationMachine>
protected void update(InputOutputPair<? extends Vector,java.lang.Double> example)
Performs a single stochastic gradient descent update of the parameters according to the given example.
Parameters:
example - The example to do a stochastic gradient step for.

protected void cleanupAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to clean up the learning algorithm's state after learning has finished.
Specified by:
cleanupAlgorithm in class AbstractAnytimeBatchLearner<java.util.Collection<? extends InputOutputPair<? extends Vector,java.lang.Double>>,FactorizationMachine>
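To connect update(example) to the math in Rendle's papers cited above, here is an illustrative, self-contained sketch of one stochastic gradient step on a squared-error loss. It is not this class's source code: the field names (bias, weights, factors), the loop structure, and the exact gradient scaling are assumptions made for the illustration.

```java
/**
 * Illustrative single SGD step for a factorization machine with squared error,
 * in the spirit of Rendle's Algorithm 1. A sketch only, not the library code.
 */
final class FmSgdStepSketch
{
    double bias;          // w0
    double[] weights;     // w_i, one per input dimension
    double[][] factors;   // factors[f][i]: factorCount rows, dimensionality columns

    double learningRate;
    double biasRegularization;
    double weightRegularization;
    double factorRegularization;

    void update(final double[] x, final double y)
    {
        // Model prediction: w0 + sum_i w_i x_i
        //   + 0.5 * sum_f ((sum_i v_fi x_i)^2 - sum_i (v_fi x_i)^2).
        final int d = x.length;
        final int k = factors.length;
        final double[] factorSums = new double[k];
        double prediction = bias;
        for (int i = 0; i < d; i++)
        {
            prediction += weights[i] * x[i];
        }
        for (int f = 0; f < k; f++)
        {
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int i = 0; i < d; i++)
            {
                final double term = factors[f][i] * x[i];
                sum += term;
                sumOfSquares += term * term;
            }
            factorSums[f] = sum;
            prediction += 0.5 * (sum * sum - sumOfSquares);
        }

        // Squared-error derivative with respect to the prediction; the constant
        // factor of 2 from the squared error is absorbed into the learning rate.
        final double error = prediction - y;

        // Gradient step on each parameter, with its own L2 regularization term.
        bias -= learningRate * (error + biasRegularization * bias);
        for (int i = 0; i < d; i++)
        {
            weights[i] -= learningRate
                * (error * x[i] + weightRegularization * weights[i]);
        }
        for (int f = 0; f < k; f++)
        {
            for (int i = 0; i < d; i++)
            {
                // d(prediction)/d(v_fi) = x_i * (sum_j v_fj x_j) - v_fi * x_i^2,
                // using the sums cached before any parameter was changed.
                final double gradient =
                    x[i] * factorSums[f] - factors[f][i] * x[i] * x[i];
                factors[f][i] -= learningRate
                    * (error * gradient + factorRegularization * factors[f][i]);
            }
        }
    }
}
```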
public double getTotalChange()
Gets the total change from the current iteration.

public double getTotalError()
Gets the total squared error from the current iteration.

public double getRegularizationPenalty()
Gets the regularization penalty term for the current result.

public double getObjective()
Gets the total objective, which is the mean squared error plus the regularization terms.
public NamedValue<? extends java.lang.Number> getPerformance()
Description copied from interface: MeasurablePerformanceAlgorithm
Gets the name-value pair that describes the current performance of the algorithm.
Specified by:
getPerformance in interface MeasurablePerformanceAlgorithm
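Continuing the hypothetical usage sketch near the top of this page, the monitoring accessors can be polled per iteration, for example from an iteration listener registered before calling learn(data). This is a sketch under assumptions: the IterativeAlgorithmListener interface with its four callbacks, and the NamedValue accessors getName() and getValue(), are assumed from the Foundry's gov.sandia.cognition.algorithm and gov.sandia.cognition.util packages.

```java
// Hypothetical monitoring hook, registered on 'learner' before learn(data) runs.
learner.addIterativeAlgorithmListener(new IterativeAlgorithmListener()
{
    public void algorithmStarted(final IterativeAlgorithm algorithm) {}
    public void algorithmEnded(final IterativeAlgorithm algorithm) {}
    public void stepStarted(final IterativeAlgorithm algorithm) {}

    public void stepEnded(final IterativeAlgorithm algorithm)
    {
        // getPerformance() wraps one progress number with a descriptive name.
        final NamedValue<? extends Number> performance = learner.getPerformance();
        System.out.println(learner.getIteration() + ": "
            + performance.getName() + " = " + performance.getValue());

        // Per the summaries above: the objective is the mean squared error
        // plus the regularization penalty.
        System.out.println("  objective = " + learner.getObjective()
            + " (penalty = " + learner.getRegularizationPenalty() + ")");
    }
});
```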
public double getLearningRate()
Gets the learning rate.

public void setLearningRate(double learningRate)
Sets the learning rate.
Parameters:
learningRate - The learning rate. Must be positive.