Spaces:
Runtime error
Runtime error
Al John Lexter Lozano
Initial commit-make model as importable module and add simple gradio interface
ead2dcb
#!/usr/bin/python
import math
import pickle

# The only characters the model scores: the 26 lowercase ASCII letters and
# the space character. Everything else is discarded by normalize().
accepted_chars = 'abcdefghijklmnopqrstuvwxyz '

# Maps each accepted character to its row/column index in the transition
# matrix built by train() (counts[pos[a]][pos[b]]).
pos = {char: idx for idx, char in enumerate(accepted_chars)}
def normalize(line):
    """ Return the lowercased characters of line that belong to accepted_chars.

    Digits, punctuation, infrequent symbols and anything else outside
    accepted_chars are dropped, which keeps the model relatively small.
    The result is a list of characters, not a str. """
    kept = []
    for ch in line:
        lowered = ch.lower()
        if lowered in accepted_chars:
            kept.append(lowered)
    return kept
def ngram(n, l):
    """ Lazily yield every window of n consecutive characters of l.

    The text is passed through normalize() first, and each window is joined
    back into a string. Text with fewer than n accepted characters yields
    nothing. """
    chars = normalize(l)
    last_start = len(chars) - n
    start = 0
    while start <= last_start:
        yield ''.join(chars[start:start + n])
        start += 1
def train(big_path='big.txt', good_path='good.txt', bad_path='bad.txt',
          model_path='gib_model.pki'):
    """ Train the character-transition model and write it as a pickle file.

    Counts adjacent character pairs in a training corpus and turns each row
    of counts into log probabilities. It then scores known-good and known-bad
    sample lines with avg_transition_prob. Finally it stores the matrix with
    a threshold halfway between the worst good score and the best bad score.

    The path parameters default to the file names this function has always
    used, so train() with no arguments behaves as before.

    Args:
        big_path: training corpus; every line is normalized and its adjacent
            character pairs are counted.
        good_path: sample lines that must score above the threshold.
        bad_path: sample lines that must score below the threshold.
        model_path: output file; receives a pickled dict with keys 'mat'
            (the log-probability matrix) and 'thresh' (the threshold).

    Raises:
        ValueError: if good_path or bad_path contains no lines.
        AssertionError: if some bad line scores at least as high as some
            good line, i.e. the samples cannot be separated.
    """
    k = len(accepted_chars)
    # Assume we have seen 10 of each character pair. This acts as a kind of
    # prior or smoothing factor. This way, if we see a character transition
    # live that we've never observed in the past, we won't assume the entire
    # string has 0 probability.
    # range() works on both Python 2 and 3 (xrange is Python 2 only).
    counts = [[10] * k for _ in range(k)]

    # Count transitions from big text file, taken
    # from http://norvig.com/spell-correct.html
    with open(big_path) as corpus:
        for line in corpus:
            for a, b in ngram(2, line):
                counts[pos[a]][pos[b]] += 1

    # Normalize the counts so that they become log probabilities.
    # We use log probabilities rather than straight probabilities to avoid
    # numeric underflow issues with long texts.
    # This contains a justification:
    # http://squarecog.wordpress.com/2009/01/10/dealing-with-underflow-in-joint-probability-calculations/
    for row in counts:
        total = float(sum(row))
        # Slice assignment updates the existing row lists in place.
        row[:] = [math.log(c / total) for c in row]

    # Find the probability of generating a few arbitrarily chosen good and
    # bad phrases.
    with open(good_path) as good_file:
        good_probs = [avg_transition_prob(l, counts) for l in good_file]
    with open(bad_path) as bad_file:
        bad_probs = [avg_transition_prob(l, counts) for l in bad_file]
    if not good_probs or not bad_probs:
        raise ValueError(
            'good and bad sample files must each contain at least one line')

    worst_good = min(good_probs)
    best_bad = max(bad_probs)
    # Check that we actually are capable of detecting the junk. This is raised
    # explicitly rather than via `assert`, so it still fires under `python -O`.
    # The exception type matches what the old assert produced.
    if not (worst_good > best_bad):
        raise AssertionError(
            'model cannot separate the samples: worst good score {0!r} '
            '<= best bad score {1!r}'.format(worst_good, best_bad))

    # And pick a threshold halfway between the worst good and best bad inputs.
    thresh = (worst_good + best_bad) / 2

    with open(model_path, 'wb') as model_file:
        pickle.dump({'mat': counts, 'thresh': thresh}, model_file)
def avg_transition_prob(l, log_prob_mat):
    """ Return the geometric-mean transition probability of l.

    Looks up the log probability of every adjacent character pair of the
    normalized text in log_prob_mat and averages those values. The average
    is then exponentiated back to a probability. Text with fewer than two
    accepted characters has no pairs and scores 1.0. """
    pair_logs = [log_prob_mat[pos[a]][pos[b]] for a, b in ngram(2, l)]
    # Divide by 1 instead of 0 when there are no pairs.
    mean_log = sum(pair_logs, 0.0) / (len(pair_logs) or 1)
    return math.exp(mean_log)
if __name__ == '__main__':
    # Executed directly (not imported): retrain and rewrite the model file.
    train()