Basic info
- Base model: Salesforce/codegen-350M-mono
- Fine-tuning data: codeparrot/github-code-clean, filtered to Python files only
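The card does not say how the Python filter was applied. As an illustration only, a filter over dataset records might look like the sketch below. The field name `language` and the value `"Python"` are assumptions about the codeparrot/github-code-clean schema, and `python_only` is a hypothetical helper, not the author's preprocessing code.

```python
from typing import Iterable, Iterator

# Hypothetical record shape: assumed to carry a "language" column alongside
# "code" and "path", as in the codeparrot GitHub code datasets.
Record = dict


def python_only(records: Iterable[Record]) -> Iterator[Record]:
    """Yield only records whose "language" field is exactly "Python"."""
    for record in records:
        if record.get("language") == "Python":
            yield record


samples = [
    {"path": "a.py", "language": "Python", "code": "def f():\n    return 1\n"},
    {"path": "b.js", "language": "JavaScript", "code": "function f() {}"},
]
print([r["path"] for r in python_only(samples)])  # ['a.py']
```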
Usage
from transformers import AutoTokenizer, AutoModelForCausalLM

model_type = 'kdf/python-docstring-generation'
tokenizer = AutoTokenizer.from_pretrained(model_type)
model = AutoModelForCausalLM.from_pretrained(model_type)

# Prompt layout: <|endoftext|>, the function source, then "# docstring" and an
# opening triple quote for the model to continue.
inputs = tokenizer('''<|endoftext|>
def load_excel(path):
    return pd.read_excel(path)
# docstring
"""''', return_tensors='pt')

doc_max_length = 128  # maximum number of new tokens for the docstring
generated_ids = model.generate(
    **inputs,
    max_length=inputs.input_ids.shape[1] + doc_max_length,
    do_sample=False,  # greedy decoding
    return_dict_in_generate=True,
    num_return_sequences=1,
    output_scores=True,
    pad_token_id=50256,
    eos_token_id=50256,  # <|endoftext|>
)
# The decoded text contains the prompt followed by the generated docstring.
ret = tokenizer.decode(generated_ids.sequences[0], skip_special_tokens=False)
print(ret)
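The example above inlines the prompt and prints the whole decoded sequence, prompt included. A minimal sketch of two hypothetical helpers (`build_prompt` and `clean_docstring`, not part of the model or of `transformers`) that build the same prompt layout for any function source and trim generated text at the first closing triple quote or `<|endoftext|>`. It assumes the model ends docstrings the way the prompt layout suggests.

```python
END_OF_TEXT = "<|endoftext|>"
DOC_MARKER = '# docstring\n"""'


def build_prompt(source: str) -> str:
    """Wrap a function's source in the prompt layout from the Usage example."""
    return f"{END_OF_TEXT}\n{source.rstrip()}\n{DOC_MARKER}"


def clean_docstring(generated: str) -> str:
    """Cut generated text at the earliest closing triple quote or end-of-text token."""
    for stop in ('"""', END_OF_TEXT):
        cut = generated.find(stop)
        if cut != -1:
            generated = generated[:cut]
    return generated.strip()


prompt = build_prompt("def load_excel(path):\n    return pd.read_excel(path)\n")
print(clean_docstring('\nLoad an Excel file.\n"""\n<|endoftext|>'))  # Load an Excel file.
```

To pass only the newly generated tokens to `clean_docstring`, one option is `tokenizer.decode(generated_ids.sequences[0][inputs.input_ids.shape[1]:])`.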
Prompt
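You can steer the model toward a particular docstring style or natural language by prepending an example function with its docstring (a few-shot prompt), for example: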
# English style example
inputs = tokenizer('''<|endoftext|>
def add(a, b):
    return a + b
# docstring
"""
Add two numbers.
Args:
    a: the first number to add
    b: the second number to add
Return:
    The result of a + b
"""
<|endoftext|>
def load_excel(path):
    return pd.read_excel(path)
# docstring
"""''', return_tensors='pt')

doc_max_length = 128
generated_ids = model.generate(
    **inputs,
    max_length=inputs.input_ids.shape[1] + doc_max_length,
    do_sample=False,
    return_dict_in_generate=True,
    num_return_sequences=1,
    output_scores=True,
    pad_token_id=50256,
    eos_token_id=50256,  # <|endoftext|>
)
ret = tokenizer.decode(generated_ids.sequences[0], skip_special_tokens=False)
print(ret)
# Chinese-language example. The docstring reads: "Add numbers.
# Args: a: the first addend; b: the second addend. Return: the result of the addition."
inputs = tokenizer('''<|endoftext|>
def add(a, b):
    return a + b
# docstring
"""
计算数字相加
Args:
    a: 第一个加数
    b: 第二个加数
Return:
    相加的结果
"""
<|endoftext|>
def load_excel(path):
    return pd.read_excel(path)
# docstring
"""''', return_tensors='pt')

doc_max_length = 128
generated_ids = model.generate(
    **inputs,
    max_length=inputs.input_ids.shape[1] + doc_max_length,
    do_sample=False,
    return_dict_in_generate=True,
    num_return_sequences=1,
    output_scores=True,
    pad_token_id=50256,
    eos_token_id=50256,  # <|endoftext|>
)
ret = tokenizer.decode(generated_ids.sequences[0], skip_special_tokens=False)
print(ret)
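Both few-shot prompts above follow one pattern: each example function and its docstring is preceded by `<|endoftext|>`, and the target function ends with an open docstring. A minimal sketch of a hypothetical `build_few_shot_prompt` helper that assembles that layout from (source, docstring) pairs. The whitespace is copied from the examples above and is an assumption about what the model expects, not a documented format.

```python
from typing import Sequence, Tuple

END_OF_TEXT = "<|endoftext|>"


def build_few_shot_prompt(examples: Sequence[Tuple[str, str]], target: str) -> str:
    """Concatenate (source, docstring) examples, then the target with an open docstring."""
    parts = []
    for source, docstring in examples:
        # Each example: end-of-text token, source, marker, closed docstring.
        parts.append(
            f'{END_OF_TEXT}\n{source.rstrip()}\n# docstring\n"""\n{docstring.strip()}\n"""\n'
        )
    # The target ends right after the opening triple quote for the model to continue.
    parts.append(f'{END_OF_TEXT}\n{target.rstrip()}\n# docstring\n"""')
    return "".join(parts)


prompt = build_few_shot_prompt(
    [("def add(a, b):\n    return a + b", "Add two numbers.")],
    "def load_excel(path):\n    return pd.read_excel(path)",
)
print(prompt)
```

The resulting string can be passed to `tokenizer(prompt, return_tensors='pt')` in place of the inline prompts above.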