Skip to content
89 changes: 54 additions & 35 deletions gitrevise/odb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TYPE_CHECKING,
Dict,
Generic,
Iterator,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -77,7 +78,9 @@ def short(self) -> str:
def for_object(cls, tag: str, body: bytes) -> Oid:
"""Hash an object with the given type tag and body to determine its Oid"""
hasher = hashlib.sha1()
hasher.update(tag.encode() + b" " + str(len(body)).encode() + b"\0" + body)
hasher.update(f"{tag} {len(body)}".encode())
hasher.update(b"\0")
hasher.update(body)
return cls(hasher.digest())

def __repr__(self) -> str:
Expand Down Expand Up @@ -301,54 +304,69 @@ def new_commit(
"""Directly create an in-memory commit object, without persisting it.
If a commit object with these properties already exists, it will be
returned instead."""
if author is None:
author = self.default_author
if committer is None:
committer = self.default_committer

body = b"tree " + tree.oid.hex().encode() + b"\n"
for parent in parents:
body += b"parent " + parent.oid.hex().encode() + b"\n"
body += b"author " + author + b"\n"
body += b"committer " + committer + b"\n"

body_tail = b"\n" + message
body += self.sign_buffer(body + body_tail)
body += body_tail

return Commit(self, body)

def sign_buffer(self, buffer: bytes) -> bytes:
def header_kvs(gpgsig: Optional[bytes]) -> Iterator[Tuple[bytes, bytes]]:
"""Yields each header name and value."""
yield b"tree", tree.oid.hex().encode()
yield from ((b"parent", p.oid.hex().encode()) for p in parents)
yield b"author", author or self.default_author
yield b"committer", committer or self.default_committer
if gpgsig:
yield b"gpgsig", gpgsig

def body_parts(gpgsig: Optional[bytes]) -> Iterator[bytes]:
"""Yields each chunk of the body for rendering into a contiguous buffer."""
for key, value in header_kvs(gpgsig=gpgsig):
# Key, space, value (with embedded newlines indented by space), newline.
yield from (key, b" ", value.replace(b"\n", b"\n "), b"\n")
yield b"\n"
yield message

def build(gpgsig: Optional[bytes] = None) -> bytes:
"""Render the body, optionally including the given gpgsig header."""
return b"".join(body_parts(gpgsig=gpgsig))

def get_body() -> bytes:
# Generate the unsigned body.
unsigned_body = build()
if not self.sign_commits:
return unsigned_body

# Get the signature for the unsigned body.
gpgsig = self.get_gpgsig(unsigned_body)

# Include the signature as a header in the final body.
return build(gpgsig=gpgsig)

return Commit(self, body=get_body())

def get_gpgsig(self, buffer: bytes) -> bytes:
"""Return the text of the signed commit object."""
from .utils import sh_run # pylint: disable=import-outside-toplevel

if not self.sign_commits:
return b""

key_id = self.config(
"user.signingKey", default=self.default_committer.signing_key
)
gpg = None
try:
# See git/gpg-interface.c:sign_buffer for reference
gpg = sh_run(
(self.gpg, "--status-fd=2", "-bsau", key_id),
stdout=PIPE,
stderr=PIPE,
input=buffer,
check=True,
)
except CalledProcessError as gpg:
print(gpg.stderr.decode(), file=sys.stderr, end="")
except CalledProcessError as exc:
print(exc.stderr.decode(), file=sys.stderr, end="")
print("gpg failed to sign commit", file=sys.stderr)
raise

if b"\n[GNUPG:] SIG_CREATED " not in gpg.stderr:
raise GPGSignError(gpg.stderr.decode())

signature = b"gpgsig"
for line in gpg.stdout.splitlines():
signature += b" " + line + b"\n"
return signature
# Strip CR from the line endings, in case we are on Windows.
gpg_out: bytes = gpg.stdout
return gpg_out.replace(b"\r", b"")

def new_tree(self, entries: Mapping[bytes, Entry]) -> Tree:
"""Directly create an in-memory tree object, without persisting it.
Expand All @@ -363,9 +381,11 @@ def entry_key(pair: Tuple[bytes, Entry]) -> bytes:
return name + b"/"
return name

body = b""
for name, entry in sorted(entries.items(), key=entry_key):
body += cast(bytes, entry.mode.value) + b" " + name + b"\0" + entry.oid
body = b"".join(
field
for name, entry in sorted(entries.items(), key=entry_key)
for field in (cast(bytes, entry.mode.value), b" ", name, b"\0", entry.oid)
)
return Tree(self, body)

def get_obj(self, ref: Union[Oid, str]) -> GitObj:
Expand Down Expand Up @@ -565,7 +585,7 @@ def _parse_body(self) -> None:
for hdr in re.split(rb"\n(?! )", hdrs):
# Parse out the key-value pairs from the header, handling
# continuation lines.
key, value = hdr.split(maxsplit=1)
key, value = hdr.split(b" ", maxsplit=1)
value = value.replace(b"\n ", b"\n")

self.gpgsig = None
Expand Down Expand Up @@ -759,9 +779,8 @@ def _parse_body(self) -> None:
while rest:
mode, rest = rest.split(b" ", maxsplit=1)
name, rest = rest.split(b"\0", maxsplit=1)
entry_oid = Oid(rest[:20])
rest = rest[20:]
self.entries[name] = Entry(self.repo, Mode(mode), entry_oid)
eoid, rest = rest[:20], rest[20:]
self.entries[name] = Entry(self.repo, Mode(mode), Oid(eoid))

def _persist_deps(self) -> None:
for entry in self.entries.values():
Expand Down