-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fa6c367
commit 9e3cfe5
Showing
5 changed files
with
2,084 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
from openai import OpenAI | ||
import fileinput | ||
import sys | ||
|
||
|
||
def _extract_relevant_info(text): | ||
start_index = text.find('"""') | ||
if start_index != -1: | ||
end_index = text.find('"""', start_index + 1) | ||
extracted_text = text[start_index : end_index + 3] | ||
return extracted_text | ||
else: | ||
return None | ||
|
||
|
||
def generate_docstring(file_str, fn_name, key):
    """Ask the OpenAI chat API to write a docstring for one function.

    The prompt template is read from ``../resources/prompt.txt``. Its
    ``[fn_name]``, ``[file]`` and ``[docstring_example]`` placeholders are
    replaced with ``fn_name``, the full text of ``file_str`` and the Python
    docstring template respectively. The prompt is sent as a single system
    message to ``gpt-3.5-turbo``; the first triple-quoted block of the reply
    is extracted and re-indented.

    Parameters
    ----------
    file_str : str
        Path of the source file that contains the target function.
    fn_name : str
        Name of the function to document.
    key : str
        OpenAI API key used to construct the client.

    Returns
    -------
    str
        The generated docstring including its delimiters, with an indent
        prefixed to every line and a trailing newline appended.

    Raises
    ------
    ValueError
        If the model reply contains no complete triple-quoted block.
    """
    with open(file_str, "r") as f:
        content = f.read()

    client = OpenAI(
        api_key=key
    )

    # Context managers close the resource files again; the bare open() calls
    # used previously leaked both handles.
    with open("../resources/prompt.txt", "r") as prompt_file:
        prompt = prompt_file.read()
    prompt = prompt.replace("[fn_name]", fn_name)
    prompt = prompt.replace("[file]", content)

    # TODO: replace this default docstring template with one generated by Claude?
    # (so it can support any language, not just Python)
    # NOTE(review): unlike the prompt path above, this path has no "../"
    # prefix and no ".txt" suffix -- confirm it resolves from the directory
    # this script is run in.
    with open(
        "resources/templates/function_templates/python_template", "r"
    ) as template_file:
        docstring_template = template_file.read()
    prompt = prompt.replace("[docstring_example]", docstring_template)

    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": prompt}
        ]
    )

    # The reply content may be missing; treat that like an empty reply.
    reply = completion.choices[0].message.content or ""
    docstring = _extract_relevant_info(reply)
    if docstring is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(
            f"model reply for {fn_name!r} did not contain a triple-quoted docstring"
        )

    # NOTE(review): the indent literals below are reproduced exactly as they
    # appear in the reviewed source; confirm the width matches the body
    # indentation of the files the docstring is merged into.
    docstring = docstring.replace("\n", "\n ")
    return " " + docstring + "\n"
|
||
|
||
def add_docstring(key):
    """Generate docstrings for functions added in ``diff.txt`` that lack one.

    ``diff.txt`` is read from the current working directory and scanned line
    by line. A line starting with ``+++`` selects the current file name (the
    text after the first ``/``). A line that starts with ``+def`` once spaces
    are removed opens a function. A blank added line, a blank line, or the
    last line of the diff closes it. If no line containing a triple double
    quote was seen since the previous closing line, ``generate_docstring`` is
    called for that function.

    Parameters
    ----------
    key : str
        OpenAI API key, forwarded unchanged to ``generate_docstring``.

    Returns
    -------
    dict
        Maps each file name to a dict of ``{function name: generated
        docstring}``. Files for which nothing was generated are absent.
    """
    # Read-only binary mode: the previous "+rb" also requested write access
    # that this function never uses.
    with open("diff.txt", "rb") as f:
        content = f.readlines()

    file_fns_without_docstring = {}
    filename = " "
    func_name = ""
    contains_docstring = False
    in_func = False
    last_index = len(content) - 1

    for i, raw_line in enumerate(content):
        line = raw_line.decode("utf-8")
        compact = line.replace(" ", "")

        if line.startswith("+++"):
            start_index = line.find("/")
            if start_index != -1:
                filename = line[start_index + 1 :].rstrip("\n")

        if compact.startswith("+def"):
            in_func = True
            func_name = compact.split("+def")[1].split("(")[0]

        # Record a docstring marker *before* the end-of-block check so that a
        # docstring sitting on the very last line of the diff is not missed.
        if '"""' in line:
            contains_docstring = True

        if compact in ("+\n", "\n") or i == last_index:
            if in_func and not contains_docstring:
                # setdefault keeps this safe even when a function block
                # closes before any "+++" header has been seen.
                per_file = file_fns_without_docstring.setdefault(filename, {})
                per_file[func_name] = generate_docstring(filename, func_name, key)
            in_func = False
            contains_docstring = False
            func_name = ""

    return file_fns_without_docstring
|
||
|
||
def merge_docstring(file_fns_without_docstring):
    """Insert generated docstrings into their source files, in place.

    For every file, each line is checked for a ``def`` of one of the listed
    functions. Once such a definition is open, the first line containing
    ``):`` or ``) ->`` is taken as the end of the signature and the docstring
    is scheduled for insertion directly after that line. The file is then
    rewritten through ``fileinput`` with the scheduled text added.

    Parameters
    ----------
    file_fns_without_docstring : dict
        Maps file names to dicts of ``{function name: docstring text}``. The
        docstring text is written verbatim, so it must already carry its own
        indentation and trailing newline.

    Returns
    -------
    None
        The files are edited in place.
    """
    import re  # local import keeps this block self-contained

    for filename, fns_without_docstring in file_fns_without_docstring.items():
        # Match the exact function name right after ``def``. The former
        # substring test treated e.g. "add" as a match for
        # "def add_docstring(" and documented the wrong function.
        def_patterns = {
            fn_name: re.compile(r"\bdef\s+" + re.escape(fn_name) + r"\s*[(\[]")
            for fn_name in fns_without_docstring
        }

        # Read-only access is enough for the scanning pass.
        with open(filename, "rb") as f:
            content = f.readlines()

        fn_wo_doc = False
        current_docstring = ""
        docstring_placement = {}
        for i, raw_line in enumerate(content):
            line = raw_line.decode("utf-8")

            for fn_name, pattern in def_patterns.items():
                if pattern.search(line):
                    fn_wo_doc = True
                    current_docstring = fns_without_docstring[fn_name]
                    break

            if fn_wo_doc and ("):" in line or ") ->" in line):
                # ``i`` is 0-based while the rewrite pass below counts from
                # 1, so ``i + 2`` is the line right after the signature end.
                docstring_placement[i + 2] = current_docstring
                current_docstring = ""
                fn_wo_doc = False

        # With inplace=True, fileinput redirects print() into the file.
        with fileinput.input(files=(filename,), inplace=True) as file:
            for line_num, line in enumerate(file, start=1):
                if line_num in docstring_placement:
                    print(docstring_placement[line_num], end="")
                print(line, end="")
|
||
|
||
if __name__ == "__main__":
    # The OpenAI API key is the only command-line argument. Exit with a usage
    # message rather than an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} OPENAI_API_KEY")
    key = sys.argv[1]
    # First collect generated docstrings per file, then write them back.
    docstring_dict = add_docstring(key)
    merge_docstring(docstring_dict)
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
def add_docstring(key):
    # Scans "diff.txt" for added functions that have no docstring and
    # collects a generated docstring for each one, grouped by file name.
    # `key` is only forwarded to generate_docstring; the returned dict maps
    # file name -> {function name: generated docstring}.
    # NOTE(review): this copy looks like the docstring-less sample that the
    # prompt test script reads -- confirm before adding a real docstring.
    # NOTE(review): generate_docstring is not defined or imported in the
    # visible part of this file.
    # diff text file all strings, parse the file to only fetch statements with additions
    # changed file names
    with open("diff.txt", "+rb") as f:
        # intelligent regex
        content = f.readlines()
    file_fns_without_docstring = {}
    filename = " "
    contains_docstring = False
    in_func = False
    for i, line in enumerate(content):
        line = line.decode("utf-8")  # Decode the bytes to a string
        # A "+++" header names the file the following hunks belong to; the
        # name is everything after the first "/".
        if line.startswith("+++"):
            start_index = line.find("/")
            if start_index != -1:
                filename = line[start_index + 1 :].rstrip("\n")
            # Start a fresh per-file result dict.
            # NOTE(review): this name is unbound if a function block closes
            # before any "+++" line has been seen.
            fns_without_docstring = {}
        # An added "def" line (spaces ignored) opens a function; its name is
        # the text between "+def" and the first "(".
        if line.replace(" ", "").startswith("+def"):
            in_func = True
            func_name = line.replace(" ", "").split("+def")[1].split("(")[0]
        # regex to check if there exists a docstring
        # A blank added line, a blank line, or the last line of the diff
        # closes the current function block.
        if (
            line.replace(" ", "") == "+\n"
            or line.replace(" ", "") == "\n"
            or i == len(content) - 1
        ):
            if in_func and not contains_docstring:
                fns_without_docstring[func_name] = generate_docstring(
                    filename, func_name, key
                )
                if filename not in file_fns_without_docstring:
                    file_fns_without_docstring[filename] = {}
                file_fns_without_docstring[filename] = fns_without_docstring
            # Reset the per-function state for the next block.
            in_func = False
            contains_docstring = False
            func_name = ""
        # Any line containing a triple double quote counts as a docstring.
        # NOTE(review): this runs after the end-of-diff check above, so a
        # docstring on the very last line is not taken into account.
        if '"""' in line:
            contains_docstring = True
    return file_fns_without_docstring
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from openai import OpenAI | ||
# Manual check of the docstring prompt: fills the prompt template for the
# `add_docstring` function of the demo script, sends it to the chat API and
# prints the raw reply.
client = OpenAI(
    api_key="Add key"
)

with open("../test_scripts/demo3.py", "r") as f:
    content = f.read()

# Context managers close the resource files again; the bare open() calls used
# previously left both handles open.
with open("../resources/prompt.txt", "r") as prompt_file:
    prompt = prompt_file.read()
prompt = prompt.replace("[fn_name]", "add_docstring")
prompt = prompt.replace("[file]", content)

# TODO: replace this default docstring template with one generated by Claude?
# (so it can support any language, not just Python)
with open(
    "../resources/templates/function_templates/python_template.txt", "r"
) as template_file:
    docstring_template = template_file.read()
prompt = prompt.replace("[docstring_example]", docstring_template)

completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": prompt}
    ]
)

print(completion.choices[0].message.content)
|
||
|
Oops, something went wrong.