Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 109 additions & 9 deletions src/ghstack/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,21 @@ async def _find_head_with_tree(
commit, commit_tree = line.split()
if commit_tree == tree:
return commit

# Some historical or externally rewritten ghstack heads do not keep the
# old head commit in first-parent history. The source id is still enough
# to recover the merge base as long as the tree object is present locally.
if await sh.agit("cat-file", "-e", f"{tree}^{{tree}}", exitcode=True):
return await sh.agit(
"commit-tree",
tree,
input="Synthetic ghstack pull merge base\n\n[ghstack-poisoned]\n",
)

raise RuntimeError(
"Could not find the previously checked out ghstack head commit. "
"The local ghstack-source-id does not appear in the remote head history."
"The local ghstack-source-id does not appear in the remote head history, "
"and the corresponding tree is not available locally."
)


Expand Down Expand Up @@ -147,18 +159,33 @@ async def _finish_pull(sh: ghstack.shell.Shell, state: Dict[str, Any]) -> None:
await _clear_state(sh)


async def main(
async def _local_ghstack_stack(
sh: ghstack.shell.Shell, *, github_url: str
) -> List[str]:
stack = []
commit = "HEAD"
while True:
commit_msg = await sh.agit("log", "-1", "--format=%B", commit)
if ghstack.diff.PullRequestResolved.search(commit_msg, github_url) is None:
break
commit_hash = await sh.agit("rev-parse", commit)
stack.append(commit_hash)
parents = (await sh.agit("rev-list", "--parents", "-n", "1", commit)).split()
if len(parents) != 2:
break
commit = parents[1]
stack.reverse()
return stack


async def _pull_current(
github: ghstack.github.GitHubEndpoint,
sh: ghstack.shell.Shell,
remote_name: str,
github_url: str,
pull_request: Optional[str] = None,
continue_: bool = False,
parent_override: Optional[str] = None,
) -> None:
if continue_:
await _finish_pull(sh, await _read_state(sh))
return

params = await _resolve_params(
pull_request=pull_request,
github_url=github_url,
Expand Down Expand Up @@ -216,7 +243,15 @@ async def main(

returncode, merge_tree_output = await _run_git_for_status(
sh,
["merge-tree", "--write-tree", "--messages", remote_head, local_imputed_head],
[
"merge-tree",
"--write-tree",
"--messages",
"--merge-base",
old_head,
remote_head,
local_imputed_head,
],
)
merged_tree = merge_tree_output.splitlines()[0] if returncode == 0 else None

Expand All @@ -229,7 +264,11 @@ async def main(
if m_remote_source_id is not None
else await sh.agit("rev-parse", f"{remote_orig}^{{tree}}")
)
remote_orig_parent = await sh.agit("rev-parse", f"{remote_orig}^")
remote_orig_parent = (
parent_override
if parent_override is not None
else await sh.agit("rev-parse", f"{remote_orig}^")
)

author_name = await sh.agit("log", "-1", "--format=%an", "HEAD")
author_email = await sh.agit("log", "-1", "--format=%ae", "HEAD")
Expand Down Expand Up @@ -268,3 +307,64 @@ async def main(
},
)
await sh.agit("checkout", pulled_orig)


async def _pull_stack(
github: ghstack.github.GitHubEndpoint,
sh: ghstack.shell.Shell,
remote_name: str,
github_url: str,
stack: List[str],
) -> None:
current_head: Optional[str] = None
for i, commit in enumerate(stack):
if i == 0:
await sh.agit("checkout", commit)
parent_override = None
else:
assert current_head is not None
await sh.agit("checkout", current_head)
await sh.agit("cherry-pick", commit)
parent_override = await sh.agit("rev-parse", "HEAD^")

await _pull_current(
github=github,
sh=sh,
remote_name=remote_name,
github_url=github_url,
parent_override=parent_override,
)
current_head = await sh.agit("rev-parse", "HEAD")


async def main(
github: ghstack.github.GitHubEndpoint,
sh: ghstack.shell.Shell,
remote_name: str,
github_url: str,
pull_request: Optional[str] = None,
continue_: bool = False,
) -> None:
if continue_:
await _finish_pull(sh, await _read_state(sh))
return

if pull_request is None:
stack = await _local_ghstack_stack(sh, github_url=github_url)
if len(stack) > 1:
await _pull_stack(
github=github,
sh=sh,
remote_name=remote_name,
github_url=github_url,
stack=stack,
)
return

await _pull_current(
github=github,
sh=sh,
remote_name=remote_name,
github_url=github_url,
pull_request=pull_request,
)
38 changes: 38 additions & 0 deletions test/pull/rewritten_head_history.py.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from ghstack.test_prelude import *

await init_test()

await commit("A")
(A,) = await gh_submit("Initial")
old_orig = A.orig

await write_file_and_add("remote.txt", "remote change")
await git("commit", "--amend", "--no-edit")
await gh_submit("Remote update")

remote_tree = await git("rev-parse", "origin/gh/ezyang/1/head^{tree}")
rewritten_head = await get_upstream_sh().agit(
"commit-tree",
"-p",
"gh/ezyang/1/base",
remote_tree,
input="Rewritten remote head\n\n[ghstack-poisoned]\n",
)
await get_upstream_sh().agit(
"update-ref",
"refs/heads/gh/ezyang/1/head",
rewritten_head,
)

await checkout(old_orig)
await write_file_and_add("local.txt", "local change")
await git("commit", "--amend", "--no-edit")

await gh_pull()

assert_eq(await git("show", "HEAD:remote.txt"), "remote change")
assert_eq(await git("show", "HEAD:local.txt"), "local change")

await gh_submit("Local update")

ok()
51 changes: 51 additions & 0 deletions test/pull/stack_after_top_only_pull.py.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from ghstack.test_prelude import *

await init_test()

await commit("A")
await commit("B")
A, B = await gh_submit("Initial")
old_top = B.orig

await checkout(A)
await write_file_and_add("A.txt", "remote A")
await git("commit", "--amend", "--no-edit")
await gh_submit("Remote bottom update")

await checkout(old_top)
await write_file_and_add("B.txt", "local B")
await git("commit", "--amend", "--no-edit")

# Simulate the old top-only behavior: the top PR was pulled/recreated, but
# the lower PR in the local stack remained stale.
await gh_pull(f"https://github.com/pytorch/pytorch/pull/{B.number}")
assert_eq(
await git("rev-parse", "HEAD^^{tree}"),
await git("rev-parse", A.orig + "^{tree}"),
)

await gh_pull()

assert_eq(await git("show", "HEAD:A.txt"), "remote A")
assert_eq(await git("show", "HEAD:B.txt"), "local B")
assert_eq(
await git("rev-parse", "HEAD^^{tree}"),
await git("rev-parse", "origin/gh/ezyang/1/orig^{tree}"),
)
assert (
"ghstack-source-id: "
+ await git("rev-parse", "origin/gh/ezyang/1/orig^{tree}")
in await git("log", "-1", "--format=%B", "HEAD^")
)
assert (
"ghstack-source-id: "
+ await git("rev-parse", "origin/gh/ezyang/2/orig^{tree}")
in await git("log", "-1", "--format=%B", "HEAD")
)
pulled_tree = await git("rev-parse", "HEAD^{tree}")
await gh_pull()
assert_eq(await git("rev-parse", "HEAD^{tree}"), pulled_tree)

await gh_submit("Local update")

ok()
44 changes: 44 additions & 0 deletions test/pull/stack_parent_update.py.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from ghstack.test_prelude import *

await init_test()

await commit("A")
await commit("B")
A, B = await gh_submit("Initial")
old_top = B.orig

await checkout(A)
await write_file_and_add("A.txt", "remote A")
await git("commit", "--amend", "--no-edit")
await gh_submit("Remote bottom update")

await checkout(old_top)
await write_file_and_add("B.txt", "local B")
await git("commit", "--amend", "--no-edit")

await gh_pull()

assert_eq(await git("show", "HEAD:A.txt"), "remote A")
assert_eq(await git("show", "HEAD:B.txt"), "local B")
assert_eq(
await git("rev-parse", "HEAD^^"),
await git("rev-parse", "origin/main"),
)
assert_eq(
await git("rev-parse", "HEAD^^{tree}"),
await git("rev-parse", "origin/gh/ezyang/1/orig^{tree}"),
)
assert (
"ghstack-source-id: "
+ await git("rev-parse", "origin/gh/ezyang/1/orig^{tree}")
in await git("log", "-1", "--format=%B", "HEAD^")
)
assert (
"ghstack-source-id: "
+ await git("rev-parse", "origin/gh/ezyang/2/orig^{tree}")
in await git("log", "-1", "--format=%B", "HEAD")
)

await gh_submit("Local update")

ok()
Loading