|
#include "tune_weights.hh" |
|
|
|
#include "tune_derivatives.hh" |
|
#include "tune_instances.hh" |
|
|
|
#pragma GCC diagnostic push |
|
#pragma GCC diagnostic ignored "-Wpragmas" |
|
#pragma GCC diagnostic ignored "-Wunused-local-typedefs" |
|
#include <Eigen/Dense> |
|
#pragma GCC diagnostic pop |
|
#include <boost/program_options.hpp> |
|
|
|
#include <iostream> |
|
|
|
namespace lm { namespace interpolate { |
|
void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config, std::vector<float> &weights_out) { |
|
Instances instances(tune_file, model_names, config); |
|
Vector weights = Vector::Constant(model_names.size(), 1.0 / model_names.size()); |
|
Vector gradient; |
|
Matrix hessian; |
|
for (std::size_t iteration = 0; iteration < 10 ; ++iteration) { |
|
std::cerr << "Iteration " << iteration << ": weights ="; |
|
for (Vector::Index i = 0; i < weights.rows(); ++i) { |
|
std::cerr << ' ' << weights(i); |
|
} |
|
std::cerr << std::endl; |
|
std::cerr << "Perplexity = " << Derivatives(instances, weights, gradient, hessian) << std::endl; |
|
|
|
weights -= 0.7 * hessian.inverse() * gradient; |
|
} |
|
weights_out.assign(weights.data(), weights.data() + weights.size()); |
|
} |
|
}} |
|
|