From 2c8bcc163e1ed047af7c43edee9ca25f250df30e Mon Sep 17 00:00:00 2001
From: Gusted <postmaster@gusted.xyz>
Date: Fri, 29 Mar 2024 00:20:21 +0100
Subject: [PATCH] [REFACTOR] PKT protocol

- Use `Fprintf` to convert to hex and do padding. Simplifies the code.
- Use `Read()` and `io.ReadFull` instead of `ReadByte()`. Should improve
performance and allows for cleaner code.
- s/pktLineTypeUnknow/pktLineTypeUnknown.
- Disallow empty Pkt line per the specification.
- Disallow too large Pkt line per the specification.
- Add unit tests.
---
 cmd/hook.go      | 54 ++++++++++++++----------------
 cmd/hook_test.go | 85 ++++++++++++++++++++++++++++++++++++------------
 cmd/serv.go      |  6 +++-
 3 files changed, 93 insertions(+), 52 deletions(-)

diff --git a/cmd/hook.go b/cmd/hook.go
index 966e4a57ca..3be480618f 100644
--- a/cmd/hook.go
+++ b/cmd/hook.go
@@ -583,7 +583,7 @@ Forgejo or set your environment appropriately.`, "")
 
 	for {
 		// note: pktLineTypeUnknow means pktLineTypeFlush and pktLineTypeData all allowed
-		rs, err = readPktLine(ctx, reader, pktLineTypeUnknow)
+		rs, err = readPktLine(ctx, reader, pktLineTypeUnknown)
 		if err != nil {
 			return err
 		}
@@ -604,7 +604,7 @@ Forgejo or set your environment appropriately.`, "")
 
 	if hasPushOptions {
 		for {
-			rs, err = readPktLine(ctx, reader, pktLineTypeUnknow)
+			rs, err = readPktLine(ctx, reader, pktLineTypeUnknown)
 			if err != nil {
 				return err
 			}
@@ -699,8 +699,8 @@ Forgejo or set your environment appropriately.`, "")
 type pktLineType int64
 
 const (
-	// UnKnow type
-	pktLineTypeUnknow pktLineType = 0
+	// Unknown type
+	pktLineTypeUnknown pktLineType = 0
 	// flush-pkt "0000"
 	pktLineTypeFlush pktLineType = iota
 	// data line
@@ -714,22 +714,16 @@ type gitPktLine struct {
 	Data   []byte
 }
 
+// Reads an Pkt-Line from `in`. If requestType is not unknown, it will a
 func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType) (*gitPktLine, error) {
-	var (
-		err error
-		r   *gitPktLine
-	)
-
-	// read prefix
+	// Read length prefix
 	lengthBytes := make([]byte, 4)
-	for i := 0; i < 4; i++ {
-		lengthBytes[i], err = in.ReadByte()
-		if err != nil {
-			return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
-		}
+	if n, err := in.Read(lengthBytes); n != 4 || err != nil {
+		return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
 	}
 
-	r = new(gitPktLine)
+	var err error
+	r := &gitPktLine{}
 	r.Length, err = strconv.ParseUint(string(lengthBytes), 16, 32)
 	if err != nil {
 		return nil, fail(ctx, "Protocol: format parse error", "Pkt-Line format is wrong :%v", err)
@@ -748,11 +742,8 @@ func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType)
 	}
 
 	r.Data = make([]byte, r.Length-4)
-	for i := range r.Data {
-		r.Data[i], err = in.ReadByte()
-		if err != nil {
-			return nil, fail(ctx, "Protocol: data error", "Pkt-Line: read stdin failed : %v", err)
-		}
+	if n, err := io.ReadFull(in, r.Data); uint64(n) != r.Length-4 || err != nil {
+		return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
 	}
 
 	r.Type = pktLineTypeData
@@ -768,20 +759,23 @@ func writeFlushPktLine(ctx context.Context, out io.Writer) error {
 	return nil
 }
 
+// Write an Pkt-Line based on `data` to `out` according to the specifcation.
+// https://git-scm.com/docs/protocol-common
 func writeDataPktLine(ctx context.Context, out io.Writer, data []byte) error {
-	hexchar := []byte("0123456789abcdef")
-	hex := func(n uint64) byte {
-		return hexchar[(n)&15]
+	// Implementations SHOULD NOT send an empty pkt-line ("0004").
+	if len(data) == 0 {
+		return fail(ctx, "Protocol: write error", "Not allowed to write empty Pkt-Line")
 	}
 
 	length := uint64(len(data) + 4)
-	tmp := make([]byte, 4)
-	tmp[0] = hex(length >> 12)
-	tmp[1] = hex(length >> 8)
-	tmp[2] = hex(length >> 4)
-	tmp[3] = hex(length)
 
-	lr, err := out.Write(tmp)
+	// The maximum length of a pkt-line’s data component is 65516 bytes.
+	// Implementations MUST NOT send pkt-line whose length exceeds 65520 (65516 bytes of payload + 4 bytes of length data).
+	if length > 65520 {
+		return fail(ctx, "Protocol: write error", "Pkt-Line exceeds maximum of 65520 bytes")
+	}
+
+	lr, err := fmt.Fprintf(out, "%04x", length)
 	if err != nil || lr != 4 {
 		return fail(ctx, "Protocol: write error", "Pkt-Line response failed: %v", err)
 	}
diff --git a/cmd/hook_test.go b/cmd/hook_test.go
index 91f24ff2b4..d4e16dc411 100644
--- a/cmd/hook_test.go
+++ b/cmd/hook_test.go
@@ -14,29 +14,72 @@ import (
 )
 
 func TestPktLine(t *testing.T) {
-	// test read
 	ctx := context.Background()
-	s := strings.NewReader("0000")
-	r := bufio.NewReader(s)
-	result, err := readPktLine(ctx, r, pktLineTypeFlush)
-	assert.NoError(t, err)
-	assert.Equal(t, pktLineTypeFlush, result.Type)
 
-	s = strings.NewReader("0006a\n")
-	r = bufio.NewReader(s)
-	result, err = readPktLine(ctx, r, pktLineTypeData)
-	assert.NoError(t, err)
-	assert.Equal(t, pktLineTypeData, result.Type)
-	assert.Equal(t, []byte("a\n"), result.Data)
+	t.Run("Read", func(t *testing.T) {
+		s := strings.NewReader("0000")
+		r := bufio.NewReader(s)
+		result, err := readPktLine(ctx, r, pktLineTypeFlush)
+		assert.NoError(t, err)
+		assert.Equal(t, pktLineTypeFlush, result.Type)
 
-	// test write
-	w := bytes.NewBuffer([]byte{})
-	err = writeFlushPktLine(ctx, w)
-	assert.NoError(t, err)
-	assert.Equal(t, []byte("0000"), w.Bytes())
+		s = strings.NewReader("0006a\n")
+		r = bufio.NewReader(s)
+		result, err = readPktLine(ctx, r, pktLineTypeData)
+		assert.NoError(t, err)
+		assert.Equal(t, pktLineTypeData, result.Type)
+		assert.Equal(t, []byte("a\n"), result.Data)
 
-	w.Reset()
-	err = writeDataPktLine(ctx, w, []byte("a\nb"))
-	assert.NoError(t, err)
-	assert.Equal(t, []byte("0007a\nb"), w.Bytes())
+		s = strings.NewReader("0004")
+		r = bufio.NewReader(s)
+		result, err = readPktLine(ctx, r, pktLineTypeData)
+		assert.Error(t, err)
+		assert.Nil(t, result)
+
+		data := strings.Repeat("x", 65516)
+		r = bufio.NewReader(strings.NewReader("fff0" + data))
+		result, err = readPktLine(ctx, r, pktLineTypeData)
+		assert.NoError(t, err)
+		assert.Equal(t, pktLineTypeData, result.Type)
+		assert.Equal(t, []byte(data), result.Data)
+
+		r = bufio.NewReader(strings.NewReader("fff1a"))
+		result, err = readPktLine(ctx, r, pktLineTypeData)
+		assert.Error(t, err)
+		assert.Nil(t, result)
+	})
+
+	t.Run("Write", func(t *testing.T) {
+		w := bytes.NewBuffer([]byte{})
+		err := writeFlushPktLine(ctx, w)
+		assert.NoError(t, err)
+		assert.Equal(t, []byte("0000"), w.Bytes())
+
+		w.Reset()
+		err = writeDataPktLine(ctx, w, []byte("a\nb"))
+		assert.NoError(t, err)
+		assert.Equal(t, []byte("0007a\nb"), w.Bytes())
+
+		w.Reset()
+		data := bytes.Repeat([]byte{0x05}, 288)
+		err = writeDataPktLine(ctx, w, data)
+		assert.NoError(t, err)
+		assert.Equal(t, append([]byte("0124"), data...), w.Bytes())
+
+		w.Reset()
+		err = writeDataPktLine(ctx, w, nil)
+		assert.Error(t, err)
+		assert.Empty(t, w.Bytes())
+
+		w.Reset()
+		data = bytes.Repeat([]byte{0x64}, 65516)
+		err = writeDataPktLine(ctx, w, data)
+		assert.NoError(t, err)
+		assert.Equal(t, append([]byte("fff0"), data...), w.Bytes())
+
+		w.Reset()
+		err = writeDataPktLine(ctx, w, bytes.Repeat([]byte{0x64}, 65516+1))
+		assert.Error(t, err)
+		assert.Empty(t, w.Bytes())
+	})
 }
diff --git a/cmd/serv.go b/cmd/serv.go
index 9d26515254..d5c54f91b8 100644
--- a/cmd/serv.go
+++ b/cmd/serv.go
@@ -14,6 +14,7 @@ import (
 	"regexp"
 	"strconv"
 	"strings"
+	"testing"
 	"time"
 	"unicode"
 
@@ -106,7 +107,10 @@ func fail(ctx context.Context, userMessage, logMsgFmt string, args ...any) error
 				logMsg = userMessage + ". " + logMsg
 			}
 		}
-		_ = private.SSHLog(ctx, true, logMsg)
+		// Don't send an log if this is done in a test and no InternalToken is set.
+		if !testing.Testing() || setting.InternalToken != "" {
+			_ = private.SSHLog(ctx, true, logMsg)
+		}
 	}
 	return cli.Exit("", 1)
 }