Large Language Models are Bad at Refactoring Code
William Zeng - October 31st, 2023
Refactoring is really hard
If you've ever refactored a thousand-line Python function, you know that it's a huge pain and extremely error-prone.
When extracting a function out of a larger function or class, we have to do a few things (see the toy example after this list):
- Gather the correct input variables
- Ensure the state is correctly returned
- Handle any side effects
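Here's what that looks like for even a two-line loop body (the function and variable names below are made up):

# Before: the loop body reads `price`, updates `total`, and has a side effect.
def report(prices):
    total = 0.0
    for price in prices:
        total += price
        print(f"running total: {total}")
    return total

# After: the extracted helper needs `price` and `total` as inputs, and must
# return `total`, otherwise the caller silently loses the update.
def add_and_log(price, total):
    total += price
    print(f"running total: {total}")
    return total

def report(prices):
    total = 0.0
    for price in prices:
        total = add_and_log(price, total)
    return total

Scale this up to a few hundred lines and dozens of variables, and it's easy to drop a return value or duplicate a side effect.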
We've been looking for the best use cases to focus on for Sweep, and one of these has been writing and executing unit tests.
This article is about refactoring, but don't worry, we didn't forget about testing ;).
The problem is that thousand-line Python functions are not only hard to read, they're hard to test. And personally, I'd rather have tests for my messy orchestration function than tests for a small util that rarely changes.
We decided that we'd set out on a simple task for the Halloween weekend: get Sweep to perfectly refactor our own code.
Language models can't do this
If you've read our previous blog on why modifying code with GPT4 is hard, you'll know that a refactor is even harder.
GPT4 has to perfectly extract all of the inputs to a 50-75 line snippet of code out of a larger snippet (hundreds of lines?).
To do this using pure text, a language model needs to delete exactly the right span of lines, copy these perfectly, and then handle the function call as well as the imports.
When generating the copied lines, it's prone to continue generating code based on the original file. Say we want GPT4 to generate a function wrapping the extracted code on lines 2-3.
1: def print_array(a):
<<<<<<< EXTRACT
2:     for item in a:
3:         print(item)
>>>>>>>
4:     print(a)
5:     return a
GPT4 will extract this, but when it generates the new function's signature, you might get:
1: def print_array(a):
2:     print_items(a)
3:     print(a) # copied twice by accident!
4:     return a
5:
6: def print_items(a):
7:     for item in a:
8:         print(item)
9:     print(a) # copied twice by accident!
It's possible to rely on a unit test feedback loop to fix this (write the new code, check that the tests pass, and then fix anything broken), but it's a bit too chaotic in practice. We might not have unit tests for the code, or the unit tests might pass but not cover this case (arguably worse!).
We wanted something deterministic, so we used Rope. Rope is a refactoring library for Python that uses the Python AST, allowing for refactors that work 100% of the time.
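For reference, here's roughly what driving Rope's extract-method refactoring looks like. This is a minimal sketch: the file path is a placeholder, and in practice the character offsets come from the matched span (more on that below) rather than a string search.

from rope.base.project import Project
from rope.refactor.extract import ExtractMethod

project = Project(".")  # root of the repository being refactored
resource = project.get_resource("vector_db.py")  # placeholder path

source = resource.read()
# Rope works on character offsets into the file, not line numbers.
start_offset = source.index("documents = []")
end_marker = "collection_name = parse_collection_name(repo_full_name)"
end_offset = source.index(end_marker) + len(end_marker)

extractor = ExtractMethod(project, resource, start_offset, end_offset)
changes = extractor.get_changes("prepare_documents_metadata_ids")
project.do(changes)  # rewrites the file with the new function and its call site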
To use Rope, we can ask GPT4 to define the sections of code that it wants to extract. These should correspond to the request. Say we want to extract the code that prepares the document snippets and metadata out of our `get_deeplake_vs_from_repo` function.
Request:
Extract the code that prepares documents, metadatas, and ids into
a new function named `prepare_documents_metadata_ids`.
This function should take `snippets`, `cloned_repo`, and `files_to_scores` as parameters
and return `documents`, `metadatas`, and `ids`.
We have Sweep generate this response:
<contextual_request_analysis>
The user request is to refactor the `get_deeplake_vs_from_repo` function in the `vector_db.py` file by extracting certain sections of the code into separate functions. The sections to be extracted are as follows:
The code that prepares documents, metadatas, and ids. This section starts with the line `documents = []` and ends with the line `ids.append(f"{gh_file_path}:{snippet.start}:{snippet.end}")`.
</contextual_request_analysis>
<new_function_names>
"prepare_documents_metadata_ids"
</new_function_names>
<extractions>
<<<<<<< EXTRACT
documents = []
metadatas = []
ids = []
for snippet in snippets:
    documents.append(snippet.get_snippet(add_ellipsis=False, add_lines=False))
    metadata = {
        "file_path": snippet.file_path[len(cloned_repo.cache_dir) + 1 :],
        "start": snippet.start,
        "end": snippet.end,
        "score": files_to_scores[snippet.file_path],
    }
    metadatas.append(metadata)
    gh_file_path = snippet.file_path[len("repo/") :]
    ids.append(f"{gh_file_path}:{snippet.start}:{snippet.end}")
logger.info(f"Getting list of all files took {time.time() - start}")
logger.info(f"Received {len(documents)} documents from repository {repo_full_name}")
collection_name = parse_collection_name(repo_full_name)
>>>>>>>
</extractions>
Then we take the `<<<<<<< EXTRACT ... >>>>>>>` section out of the `<extractions>` tag and fuzzy match it with the original code.
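Parsing that response is the easy part. Here's a rough sketch (the tags and markers match our prompt format; the helper itself is illustrative):

import re

def parse_extractions(response: str) -> list[str]:
    """Pull each <<<<<<< EXTRACT ... >>>>>>> block out of the <extractions> tag."""
    match = re.search(r"<extractions>(.*?)</extractions>", response, re.DOTALL)
    if match is None:
        return []
    return [
        block.strip("\n")
        for block in re.findall(r"<<<<<<< EXTRACT\n(.*?)>>>>>>>", match.group(1), re.DOTALL)
    ]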
Covering GPT4's weaknesses
This sounds great when GPT4 cooperates, but it doesn't work out of the box.
Fuzzy matching helps with forgetting
We have a two-pointer-based scoring algorithm that finds the code span in the original file most similar to the GPT4-generated span.
GPT4 often overlooks comments, newlines, and logger statements when extracting sections. We added a section in our scoring algorithm that makes the score agnostic to the most commonly forgotten sections.
Here's a snippet:
elif (
    t_line.strip() == ""
    or t_line.strip().startswith("#")
    or t_line.strip().startswith("//")
    or t_line.strip().startswith("print")
    or t_line.strip().startswith("logger")
    or t_line.strip().startswith("console.")
):
    # skip scoring or put a relatively high score
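For a sense of where that check lives, here's a simplified (brute-force) sketch of the idea: score every candidate window of the original file against the GPT4-generated span, with the commonly forgotten lines excused, and keep the best one. This is a simplification, not our exact two-pointer implementation.

from difflib import SequenceMatcher

def line_score(t_line: str, o_line: str) -> float:
    stripped = t_line.strip()
    if stripped == "" or stripped.startswith(("#", "//", "print", "logger", "console.")):
        return 1.0  # commonly forgotten lines shouldn't drag the score down
    return SequenceMatcher(None, stripped, o_line.strip()).ratio()

def best_span(target_lines: list[str], original_lines: list[str]) -> tuple[tuple[int, int], float]:
    """Find the window of `original_lines` most similar to `target_lines`.
    Returns ((start, end), score) with an end-exclusive line range."""
    if not target_lines:
        return (0, 0), 0.0
    window = len(target_lines)
    best_score, best_range = -1.0, (0, window)
    for start in range(max(len(original_lines) - window + 1, 1)):
        candidate = original_lines[start : start + window]
        score = sum(line_score(t, o) for t, o in zip(target_lines, candidate)) / window
        if score > best_score:
            best_score, best_range = score, (start, start + window)
    return best_range, best_score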
Tricking GPT4 to prevent hallucinations
GPT4 sometimes thinks it must both extract the code and wrap it in a new function, when it only needs to copy the code verbatim. As a toy example, if it were supposed to extract
print(x)
x += 1
from a larger section of code, it would correctly extract the body but incorrectly add a function signature and return statement:
def (x): # incorrectly added
    print(x)
    x += 1
    return x # incorrectly added
even though we only want the code verbatim. Rope will later generate the function signatures and replace all occurrences.
We had to add a special case to our scoring algorithm to handle this. We first try matching without the first line, then try matching without the last line.
This algorithm does the following:

def (x):
    print(x)
    x += 1
    return x

-> delete the first line ->

print(x)
x += 1
return x

-> delete the last line ->

print(x)
x += 1
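In code, the special case amounts to retrying the match with the suspect lines trimmed off. A sketch using the hypothetical `best_span` helper from the earlier snippet (again, a simplification rather than our exact code):

def match_with_trimming(target_lines: list[str], original_lines: list[str]) -> tuple[int, int]:
    # Try the span as generated, then without a hallucinated first line
    # (e.g. `def (x):`), then also without a hallucinated trailing `return x`.
    candidates = [target_lines, target_lines[1:], target_lines[1:-1]]
    best_score, best_range = -1.0, (0, 0)
    for candidate in candidates:
        if not candidate:
            continue
        span, score = best_span(candidate, original_lines)
        if score > best_score:
            best_score, best_range = score, span
    return best_range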
We also minimized the chance of this by modifying our prompt with a single word:
- Extract code verbatim from the snippets above. These snippets will be used to refactor the code according to the user request.
+ Extract code verbatim from the snippets above. These snippets will be used later to refactor the code according to the user request.
We hoped this would convince GPT4 that it doesn't need to handle converting the code to a function, and it seems to work!
Don't use the "Think step-by-step" prompt
Another issue occurs when GPT4 extracts a MASSIVE section from the original code. This is going to happen a lot, as refactor requests should (reasonably) happen in large files.
When GPT4 generates >400 lines of code to extract, it risks running past the token limit (I saw it generate >4.8k tokens for one of our files), or even forgetting to close the `</extractions>` XML tag.
One trick we found that helps with this is to add "anchors" to the chain-of-thought (CoT) prompting.
Instead of the classic "Think step-by-step to solve the problem", our CoT prompt looks like this:
<contextual_request_analysis>
Analyze the user request and outline the first and last few lines of code that should be extracted.
</contextual_request_analysis>
and then our GPT4 response will contain this line:
This section starts with the line `documents = []` and ends with the line `ids.append(f"{gh_file_path}:{snippet.start}:{snippet.end}")`
This helps a lot. Our hypothesis is that when starting to extract the span, the transformer (the underlying model architecture of GPT4) attends to and anchors the generation to the starting line `documents = []`, and then does the same for the ending line `ids.append(f"{gh_file_path}:{snippet.start}:{snippet.end}")`.
Side note: We could have GPT4 generate just the start and end lines, but generating the entire code snippet allows us to further disambiguate the output with fuzzy matching.
We use this strategy in almost all of our prompts, and it made them much more reliable. Task-specific CoT with "anchors" seems to be the way to go!
Here's our final prompt:
# Instructions
Extract code verbatim from the snippets above. These snippets will be used later to refactor the code according to the user request.
* Choose specific and very informative names for these functions under new_function_name.
* We must copy the code verbatim, so any extra leading or trailing code will cause us to fail.
* Keep whitespace and comments.
* Use EXTRACT to isolate specific code segments from the current function and place them into new, separate functions.
Respond in the following format:
<contextual_request_analysis>
Analyze the user request and outline the first and last few lines of code that should be extracted.
</contextual_request_analysis>
<new_function_names>
"new_function_name"
...
</new_function_names>
<extractions>
<<<<<<< EXTRACT
first few lines to be extracted from original_code
...
last few lines to be extracted from original_code
>>>>>>>
...
</extractions>
The final result
Finally, GPT4 generates multiple spans to refactor, as well as the function names for these new methods. We can pass Rope the extracted span (after using GPT4 to make sure the span corresponds to the user request) and the AI-generated function name to perfectly extract the code.
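The remaining glue is mechanical: the fuzzy matcher hands back a line range, Rope wants character offsets, so we convert the range and pass Rope the AI-generated name. A rough sketch (the helper and its line-range convention are illustrative):

from rope.refactor.extract import ExtractMethod

def extract_with_rope(project, resource, start_line: int, end_line: int, new_function_name: str):
    # `start_line`/`end_line` are a 0-indexed, end-exclusive line range from
    # the fuzzy matcher; Rope expects character offsets into the file.
    lines = resource.read().splitlines(keepends=True)
    start_offset = sum(len(line) for line in lines[:start_line])
    end_offset = sum(len(line) for line in lines[:end_line])  # end of the last selected line
    extractor = ExtractMethod(project, resource, start_offset, end_offset)
    project.do(extractor.get_changes(new_function_name))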
Check out this beautiful refactor:
def get_deeplake_vs_from_repo(
    cloned_repo: ClonedRepo,
    sweep_config: SweepConfig = SweepConfig(),
):
    repo_full_name = cloned_repo.repo_full_name
    repo = cloned_repo.repo
    commits = repo.get_commits()
    commit_hash = commits[0].sha
    logger.info(f"Downloading repository and indexing for {repo_full_name}...")
    start = time.time()
    logger.info("Recursively getting list of files...")
    blocked_dirs = get_blocked_dirs(repo)
    sweep_config.exclude_dirs.extend(blocked_dirs)
<<<<<<<< EXTRACTED -> prepare_lexical_search_index
    snippets, file_list = repo_to_chunks(cloned_repo.cache_dir, sweep_config)
    logger.info(f"Found {len(snippets)} snippets in repository {repo_full_name}")
    # prepare lexical search
    index = prepare_index_from_snippets(
        snippets, len_repo_cache_dir=len(cloned_repo.cache_dir) + 1
    )
    logger.print("Prepared index from snippets")
>>>>>>>
    # scoring for vector search
<<<<<<<< EXTRACTED -> compute_vector_search_scores
    files_to_scores = {}
    score_factors = []
    for file_path in tqdm(file_list):
        if not redis_client:
            score_factor = compute_score(
                file_path[len(cloned_repo.cache_dir) + 1 :], cloned_repo.git_repo
            )
            score_factors.append(score_factor)
            continue
        cache_key = hash_sha256(file_path) + CACHE_VERSION
        try:
            cache_value = redis_client.get(cache_key)
        except Exception as e:
            logger.exception(e)
            cache_value = None
        if cache_value is not None:
            score_factor = json.loads(cache_value)
            score_factors.append(score_factor)
        else:
            score_factor = compute_score(
                file_path[len(cloned_repo.cache_dir) + 1 :], cloned_repo.git_repo
            )
            score_factors.append(score_factor)
            redis_client.set(cache_key, json.dumps(score_factor))
    # compute all scores
    all_scores = get_scores(score_factors)
    files_to_scores = {
        file_path: score for file_path, score in zip(file_list, all_scores)
    }
    logger.info(f"Found {len(file_list)} files in repository {repo_full_name}")
>>>>>>>
<<<<<<<< EXTRACTED -> prepare_documents_metadata_ids
    documents = []
    metadatas = []
    ids = []
    for snippet in snippets:
        documents.append(snippet.get_snippet(add_ellipsis=False, add_lines=False))
        metadata = {
            "file_path": snippet.file_path[len(cloned_repo.cache_dir) + 1 :],
            "start": snippet.start,
            "end": snippet.end,
            "score": files_to_scores[snippet.file_path],
        }
        metadatas.append(metadata)
        gh_file_path = snippet.file_path[len("repo/") :]
        ids.append(f"{gh_file_path}:{snippet.start}:{snippet.end}")
    logger.info(f"Getting list of all files took {time.time() - start}")
    logger.info(f"Received {len(documents)} documents from repository {repo_full_name}")
    collection_name = parse_collection_name(repo_full_name)
>>>>>>>
    deeplake_vs = compute_deeplake_vs(
        collection_name, documents, ids, metadatas, commit_hash
    )
    return deeplake_vs, index, len(documents)
Sweep was able to cleanly split this large, messy function into completely modular sub-functions:
def get_deeplake_vs_from_repo(
    cloned_repo: ClonedRepo,
    sweep_config: SweepConfig = SweepConfig(),
):
    repo_full_name = cloned_repo.repo_full_name
    repo = cloned_repo.repo
    commits = repo.get_commits()
    commit_hash = commits[0].sha
    logger.info(f"Downloading repository and indexing for {repo_full_name}...")
    start = time.time()
    logger.info("Recursively getting list of files...")
    blocked_dirs = get_blocked_dirs(repo)
    sweep_config.exclude_dirs.extend(blocked_dirs)
    file_list, snippets, index = prepare_lexical_search_index(cloned_repo, sweep_config, repo_full_name)
    # scoring for vector search
    files_to_scores = compute_vector_search_scores(file_list, cloned_repo, repo_full_name)
    collection_name, documents, ids, metadatas = prepare_documents_metadata_ids(snippets, cloned_repo, files_to_scores, start, repo_full_name)
    deeplake_vs = compute_deeplake_vs(
        collection_name, documents, ids, metadatas, commit_hash
    )
    return deeplake_vs, index, len(documents)
This could use another refactoring pass, but it's way better than when we started. Even better, these functions are 10x easier to test than before, because we can write easy mocks.
Hopefully this blog was useful for you, and if you'd like to see more of the implementation, check out our repo at https://github.com/sweepai/sweep!