Skip to content

Commit 0fe5e5f

Browse files
Merge branch 'main' into fix/duplicate-global-assignments-when-reverting-helpers
2 parents c1f75ad + 02ae60b commit 0fe5e5f

File tree

9 files changed

+407
-107
lines changed

9 files changed

+407
-107
lines changed

codeflash/code_utils/git_utils.py

Lines changed: 8 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,24 @@
1616
from unidiff import PatchSet
1717

1818
from codeflash.cli_cmds.console import logger
19-
from codeflash.code_utils.compat import codeflash_cache_dir
2019
from codeflash.code_utils.config_consts import N_CANDIDATES
2120

2221
if TYPE_CHECKING:
2322
from git import Repo
2423

2524

26-
def get_git_diff(repo_directory: Path | None = None, *, uncommitted_changes: bool = False) -> dict[str, list[int]]:
25+
def get_git_diff(
26+
repo_directory: Path | None = None, *, only_this_commit: Optional[str] = None, uncommitted_changes: bool = False
27+
) -> dict[str, list[int]]:
2728
if repo_directory is None:
2829
repo_directory = Path.cwd()
2930
repository = git.Repo(repo_directory, search_parent_directories=True)
3031
commit = repository.head.commit
31-
if uncommitted_changes:
32+
if only_this_commit:
33+
uni_diff_text = repository.git.diff(
34+
only_this_commit + "^1", only_this_commit, ignore_blank_lines=True, ignore_space_at_eol=True
35+
)
36+
elif uncommitted_changes:
3237
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
3338
else:
3439
uni_diff_text = repository.git.diff(
@@ -193,84 +198,3 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
193198
return None
194199
else:
195200
return last_commit.author.name
196-
197-
198-
worktree_dirs = codeflash_cache_dir / "worktrees"
199-
patches_dir = codeflash_cache_dir / "patches"
200-
201-
202-
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
203-
repository = git.Repo(worktree_dir, search_parent_directories=True)
204-
repository.git.add(".")
205-
repository.git.commit("-m", commit_message, "--no-verify")
206-
207-
208-
def create_detached_worktree(module_root: Path) -> Optional[Path]:
209-
if not check_running_in_git_repo(module_root):
210-
logger.warning("Module is not in a git repository. Skipping worktree creation.")
211-
return None
212-
git_root = git_root_dir()
213-
current_time_str = time.strftime("%Y%m%d-%H%M%S")
214-
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
215-
216-
repository = git.Repo(git_root, search_parent_directories=True)
217-
218-
repository.git.worktree("add", "-d", str(worktree_dir))
219-
220-
# Get uncommitted diff from the original repo
221-
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
222-
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
223-
uni_diff_text = repository.git.diff(
224-
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
225-
)
226-
227-
if not uni_diff_text.strip():
228-
logger.info("No uncommitted changes to copy to worktree.")
229-
return worktree_dir
230-
231-
# Write the diff to a temporary file
232-
with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
233-
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
234-
tmp_patch_file.flush()
235-
236-
patch_path = Path(tmp_patch_file.name).resolve()
237-
238-
# Apply the patch inside the worktree
239-
try:
240-
subprocess.run(
241-
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
242-
cwd=worktree_dir,
243-
check=True,
244-
)
245-
create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
246-
except subprocess.CalledProcessError as e:
247-
logger.error(f"Failed to apply patch to worktree: {e}")
248-
249-
return worktree_dir
250-
251-
252-
def remove_worktree(worktree_dir: Path) -> None:
253-
try:
254-
repository = git.Repo(worktree_dir, search_parent_directories=True)
255-
repository.git.worktree("remove", "--force", worktree_dir)
256-
except Exception:
257-
logger.exception(f"Failed to remove worktree: {worktree_dir}")
258-
259-
260-
def create_diff_patch_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Path:
261-
repository = git.Repo(worktree_dir, search_parent_directories=True)
262-
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
263-
264-
if not uni_diff_text:
265-
logger.warning("No changes found in worktree.")
266-
return None
267-
268-
if not uni_diff_text.endswith("\n"):
269-
uni_diff_text += "\n"
270-
271-
# write to patches_dir
272-
patches_dir.mkdir(parents=True, exist_ok=True)
273-
patch_path = patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
274-
with patch_path.open("w", encoding="utf8") as f:
275-
f.write(uni_diff_text)
276-
return patch_path
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import subprocess
5+
import tempfile
6+
import time
7+
from functools import lru_cache
8+
from pathlib import Path
9+
from typing import TYPE_CHECKING, Optional
10+
11+
import git
12+
from filelock import FileLock
13+
14+
from codeflash.cli_cmds.console import logger
15+
from codeflash.code_utils.compat import codeflash_cache_dir
16+
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
17+
18+
if TYPE_CHECKING:
19+
from typing import Any
20+
21+
from git import Repo
22+
23+
24+
worktree_dirs = codeflash_cache_dir / "worktrees"
25+
patches_dir = codeflash_cache_dir / "patches"
26+
27+
if TYPE_CHECKING:
28+
from git import Repo
29+
30+
31+
@lru_cache(maxsize=1)
32+
def get_git_project_id() -> str:
33+
"""Return the first commit sha of the repo."""
34+
repo: Repo = git.Repo(search_parent_directories=True)
35+
root_commits = list(repo.iter_commits(rev="HEAD", max_parents=0))
36+
return root_commits[0].hexsha
37+
38+
39+
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
40+
repository = git.Repo(worktree_dir, search_parent_directories=True)
41+
repository.git.add(".")
42+
repository.git.commit("-m", commit_message, "--no-verify")
43+
44+
45+
def create_detached_worktree(module_root: Path) -> Optional[Path]:
46+
if not check_running_in_git_repo(module_root):
47+
logger.warning("Module is not in a git repository. Skipping worktree creation.")
48+
return None
49+
git_root = git_root_dir()
50+
current_time_str = time.strftime("%Y%m%d-%H%M%S")
51+
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
52+
53+
repository = git.Repo(git_root, search_parent_directories=True)
54+
55+
repository.git.worktree("add", "-d", str(worktree_dir))
56+
57+
# Get uncommitted diff from the original repo
58+
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
59+
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
60+
uni_diff_text = repository.git.diff(
61+
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
62+
)
63+
64+
if not uni_diff_text.strip():
65+
logger.info("No uncommitted changes to copy to worktree.")
66+
return worktree_dir
67+
68+
# Write the diff to a temporary file
69+
with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
70+
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
71+
tmp_patch_file.flush()
72+
73+
patch_path = Path(tmp_patch_file.name).resolve()
74+
75+
# Apply the patch inside the worktree
76+
try:
77+
subprocess.run(
78+
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
79+
cwd=worktree_dir,
80+
check=True,
81+
)
82+
create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
83+
except subprocess.CalledProcessError as e:
84+
logger.error(f"Failed to apply patch to worktree: {e}")
85+
86+
return worktree_dir
87+
88+
89+
def remove_worktree(worktree_dir: Path) -> None:
90+
try:
91+
repository = git.Repo(worktree_dir, search_parent_directories=True)
92+
repository.git.worktree("remove", "--force", worktree_dir)
93+
except Exception:
94+
logger.exception(f"Failed to remove worktree: {worktree_dir}")
95+
96+
97+
@lru_cache(maxsize=1)
98+
def get_patches_dir_for_project() -> Path:
99+
project_id = get_git_project_id() or ""
100+
return Path(patches_dir / project_id)
101+
102+
103+
def get_patches_metadata() -> dict[str, Any]:
104+
project_patches_dir = get_patches_dir_for_project()
105+
meta_file = project_patches_dir / "metadata.json"
106+
if meta_file.exists():
107+
with meta_file.open("r", encoding="utf-8") as f:
108+
return json.load(f)
109+
return {"id": get_git_project_id() or "", "patches": []}
110+
111+
112+
def save_patches_metadata(patch_metadata: dict) -> dict:
113+
project_patches_dir = get_patches_dir_for_project()
114+
meta_file = project_patches_dir / "metadata.json"
115+
lock_file = project_patches_dir / "metadata.json.lock"
116+
117+
# we are not supporting multiple concurrent optimizations within the same process, but keep that in case we decide to do so in the future.
118+
with FileLock(lock_file, timeout=10):
119+
metadata = get_patches_metadata()
120+
121+
patch_metadata["id"] = time.strftime("%Y%m%d-%H%M%S")
122+
metadata["patches"].append(patch_metadata)
123+
124+
meta_file.write_text(json.dumps(metadata, indent=2))
125+
126+
return patch_metadata
127+
128+
129+
def overwrite_patch_metadata(patches: list[dict]) -> bool:
130+
project_patches_dir = get_patches_dir_for_project()
131+
meta_file = project_patches_dir / "metadata.json"
132+
lock_file = project_patches_dir / "metadata.json.lock"
133+
134+
with FileLock(lock_file, timeout=10):
135+
metadata = get_patches_metadata()
136+
metadata["patches"] = patches
137+
meta_file.write_text(json.dumps(metadata, indent=2))
138+
return True
139+
140+
141+
def create_diff_patch_from_worktree(
142+
worktree_dir: Path,
143+
files: list[str],
144+
fto_name: Optional[str] = None,
145+
metadata_input: Optional[dict[str, Any]] = None,
146+
) -> dict[str, Any]:
147+
repository = git.Repo(worktree_dir, search_parent_directories=True)
148+
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
149+
150+
if not uni_diff_text:
151+
logger.warning("No changes found in worktree.")
152+
return {}
153+
154+
if not uni_diff_text.endswith("\n"):
155+
uni_diff_text += "\n"
156+
157+
project_patches_dir = get_patches_dir_for_project()
158+
project_patches_dir.mkdir(parents=True, exist_ok=True)
159+
160+
final_function_name = fto_name or metadata_input.get("fto_name", "unknown")
161+
patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch"
162+
with patch_path.open("w", encoding="utf8") as f:
163+
f.write(uni_diff_text)
164+
165+
final_metadata = {"patch_path": str(patch_path)}
166+
if metadata_input:
167+
final_metadata.update(metadata_input)
168+
final_metadata = save_patches_metadata(final_metadata)
169+
170+
return final_metadata

codeflash/code_utils/shell_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
1616
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
1717
else:
18-
SHELL_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=[\'"]?(cf-[^\s"]+)[\'"]$', re.MULTILINE)
18+
SHELL_RC_EXPORT_PATTERN = re.compile(
19+
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
20+
)
1921
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
2022

2123

codeflash/discovery/functions_to_optimize.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,16 @@ def get_functions_to_optimize(
232232

233233
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
234234
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
235-
modified_functions: dict[str, list[FunctionToOptimize]] = {}
235+
return get_functions_within_lines(modified_lines)
236+
237+
238+
def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]:
239+
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
240+
return get_functions_within_lines(modified_lines)
241+
242+
243+
def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str, list[FunctionToOptimize]]:
244+
functions: dict[str, list[FunctionToOptimize]] = {}
236245
for path_str, lines_in_file in modified_lines.items():
237246
path = Path(path_str)
238247
if not path.exists():
@@ -246,14 +255,14 @@ def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[F
246255
continue
247256
function_lines = FunctionVisitor(file_path=str(path))
248257
wrapper.visit(function_lines)
249-
modified_functions[str(path)] = [
258+
functions[str(path)] = [
250259
function_to_optimize
251260
for function_to_optimize in function_lines.functions
252261
if (start_line := function_to_optimize.starting_line) is not None
253262
and (end_line := function_to_optimize.ending_line) is not None
254263
and any(start_line <= line <= end_line for line in lines_in_file)
255264
]
256-
return modified_functions
265+
return functions
257266

258267

259268
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:

0 commit comments

Comments
 (0)