Spaces:
Build error
Build error
from torch_grammar import GrammarSampler | |
from transformers.generation.logits_process import LogitsProcessor | |
from modules import shared | |
# Module-level state shared by every GrammarLogitsProcessor instance.
# grammar_string: grammar text recorded by GrammarLogitsProcessor.__init__ and
#                 compared against new input to decide whether to rebuild.
# sampler:        GrammarSampler built from a non-blank grammar, else None.
# grammar:        callable obtained from sampler.logits_processor(), else None.
grammar_string = ''
sampler = grammar = None
class GrammarLogitsProcessor(LogitsProcessor):
    """Logits processor that filters scores through a grammar sampler.

    The sampler and the processor derived from it live in the module-level
    globals ``sampler``, ``grammar`` and ``grammar_string`` rather than on the
    instance, so the sampler is only rebuilt when the grammar text changes
    between instantiations.
    """

    def __init__(self, string):
        """Refresh the module-level grammar state for ``string``.

        Args:
            string: Grammar text. When it differs from the last recorded text
                a new ``GrammarSampler`` is built from the stripped text plus
                a trailing newline, using ``'root'`` and ``shared.tokenizer``;
                a blank/whitespace-only string clears the sampler instead.

        Raises:
            Whatever ``GrammarSampler`` raises for a grammar it rejects. In
            that case the cached state is left exactly as it was, so a retry
            with the same text attempts the build again.
        """
        global sampler, grammar, grammar_string

        if string != grammar_string:
            stripped = string.strip()
            # Build first and commit afterwards: if construction raises, the
            # recorded grammar text must not advance, otherwise the next call
            # with the same text would skip the rebuild and silently reuse a
            # sampler that belongs to the previous grammar.
            if stripped != '':
                new_sampler = GrammarSampler(stripped + '\n', 'root', shared.tokenizer)
            else:
                new_sampler = None

            sampler = new_sampler
            grammar_string = string

        # A fresh processor is requested on every instantiation, even when the
        # sampler itself was reused.
        grammar = sampler.logits_processor() if sampler is not None else None

    def __call__(self, input_ids, scores):
        """Return ``scores`` after applying the active grammar, if any.

        Args:
            input_ids: Passed through unchanged to the grammar processor.
            scores: Scores to filter; returned untouched when no grammar is
                active.

        Returns:
            The result of ``grammar(input_ids, scores)`` when a grammar is
            active, otherwise ``scores`` itself.
        """
        if grammar is None:
            return scores

        return grammar(input_ids, scores)