# Source snapshot 646ac2a (1,659 bytes).
import json
import mlxu
from EasyLM.serving import LMClient
# Declare this script's command-line flags and their defaults through mlxu.
# main() reads the configured values through FLAGS.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(**{
    # Path of the JSON file main() loads its request from.
    'input_file': '',
    # Path main() writes the JSON result to.
    'output_file': '',
    # Key names main() looks up in the loaded input JSON.
    'prefix_field': 'prefix',
    'text_field': 'text',
    'until_field': 'until',
    # Selects the LMClient call: 'loglikelihood', 'loglikelihood_rolling',
    # 'greedy_until' or 'generate'.
    'eval_type': 'loglikelihood',
    # Configuration handed to the LMClient constructor.
    'lm_client': LMClient.get_default_config(),
})
def _require_field(input_data, field):
    """Return ``input_data[field]``, raising a descriptive KeyError if absent.

    Args:
        input_data: Object parsed from the input JSON file.
        field: Key to look up (the value of one of the ``*_field`` flags).

    Raises:
        KeyError: If ``field`` is missing. The message names the key and the
            eval type that needed it, so a misconfigured ``*_field`` flag or
            a malformed input file is easy to diagnose.
    """
    try:
        return input_data[field]
    except KeyError:
        raise KeyError(
            f'Input JSON is missing required field {field!r} '
            f'(needed for eval_type={FLAGS.eval_type!r})'
        ) from None


def _eval_loglikelihood(lm_client, input_data):
    """Call ``lm_client.loglikelihood(prefix, text)`` and package both outputs."""
    prefix = _require_field(input_data, FLAGS.prefix_field)
    text = _require_field(input_data, FLAGS.text_field)
    loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
    return {'loglikelihood': loglikelihoods, 'is_greedy': is_greedys}


def _eval_loglikelihood_rolling(lm_client, input_data):
    """Call ``lm_client.loglikelihood_rolling(text)`` and package both outputs."""
    text = _require_field(input_data, FLAGS.text_field)
    loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
    return {'loglikelihood': loglikelihoods, 'is_greedy': is_greedys}


def _eval_greedy_until(lm_client, input_data):
    """Call ``lm_client.greedy_until(prefix, until)`` and wrap its result."""
    prefix = _require_field(input_data, FLAGS.prefix_field)
    until = _require_field(input_data, FLAGS.until_field)
    return {'output_text': lm_client.greedy_until(prefix, until)}


def _eval_generate(lm_client, input_data):
    """Call ``lm_client.generate(prefix)`` and wrap its result."""
    prefix = _require_field(input_data, FLAGS.prefix_field)
    return {'output_text': lm_client.generate(prefix)}


# Maps each supported --eval_type value to the handler that serves it.
_EVAL_HANDLERS = {
    'loglikelihood': _eval_loglikelihood,
    'loglikelihood_rolling': _eval_loglikelihood_rolling,
    'greedy_until': _eval_greedy_until,
    'generate': _eval_generate,
}


def main(argv):
    """Run one evaluation request against an LM server and save the result.

    Loads a JSON object from ``--input_file`` and calls the ``LMClient``
    method selected by ``--eval_type``. The fields it passes to that method
    are named by the ``*_field`` flags. The returned dict is written as JSON
    to ``--output_file``.

    Args:
        argv: Positional command-line arguments supplied by ``mlxu.run``;
            unused.

    Raises:
        ValueError: If ``--eval_type`` is not supported, or if
            ``--input_file`` / ``--output_file`` is empty.
        KeyError: If the input JSON lacks a field the chosen eval type needs.
    """
    del argv  # Unused: all configuration comes from FLAGS.

    # Validate configuration before any I/O or client construction. A bad
    # eval_type or a missing output path would otherwise surface only after
    # the input was read, or after the whole evaluation had already run.
    handler = _EVAL_HANDLERS.get(FLAGS.eval_type)
    if handler is None:
        raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
    if not FLAGS.input_file:
        raise ValueError('--input_file must be set')
    if not FLAGS.output_file:
        raise ValueError('--output_file must be set')

    with mlxu.open_file(FLAGS.input_file, 'r') as fin:
        input_data = json.load(fin)

    lm_client = LMClient(FLAGS.lm_client)
    output_data = handler(lm_client, input_data)

    with mlxu.open_file(FLAGS.output_file, 'w') as fout:
        json.dump(output_data, fout)
# Script entry point: hand `main` to mlxu.run (presumably absl-style flag
# parsing followed by a call to main(argv) — confirm against mlxu).
if __name__ == "__main__":
    mlxu.run(main)
|