ugit/internal/ssh/wish.go

178 lines
4.5 KiB
Go
Raw Normal View History

package ssh
import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"text/tabwriter"
"go.jolheiser.com/ugit/internal/git"
"github.com/charmbracelet/log"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
)
var (
	// ErrSystemMalfunction is the generic error reported to clients when an
	// operation fails on the server side; details are kept out of the reply.
	ErrSystemMalfunction = errors.New("something went wrong")

	// ErrInvalidRepo is reported to clients that ask for a repository that
	// does not exist or whose name is not acceptable.
	ErrInvalidRepo = errors.New("invalid repo")
)
// Hooks receives notifications about git operations served by Middleware.
//
// Push is invoked for a handled "git-receive-pack" and Fetch for a handled
// "git-upload-pack" or "git-upload-archive". Both receive the repository name
// as requested by the client (e.g. "repo.git" or "user/repo.git") and the
// public key of the ssh.Session that performed the operation.
type Hooks interface {
	Push(repo string, pk ssh.PublicKey)
	Fetch(repo string, pk ssh.PublicKey)
}
// Session adapts an ssh.Session so it can be handed to the git protocol
// handlers as a git.ReadWriteContexter.
type Session struct {
	s ssh.Session
}

// Read reads from the wrapped SSH session, implementing io.Reader.
func (sess Session) Read(buf []byte) (int, error) {
	return sess.s.Read(buf)
}

// Write writes to the wrapped SSH session, implementing io.Writer.
func (sess Session) Write(buf []byte) (int, error) {
	return sess.s.Write(buf)
}

// Close implements io.Closer as a deliberate no-op: the wrapped SSH session is
// left open. NOTE(review): presumably the session's lifetime is owned by the
// ssh server / middleware chain — confirm before making this close anything.
func (sess Session) Close() error {
	return nil
}

// Context returns the context.Context of the wrapped SSH session.
func (sess Session) Context() context.Context {
	return sess.s.Context()
}
// Middleware adds Git server functionality to the ssh.Server. Repos are stored
// in repoDir, and cloneURL is the prefix shown for each repository in the
// listing. The port parameter is currently unused.
//
// A session whose command is "git-receive-pack <repo>", "git-upload-pack <repo>"
// or "git-upload-archive <repo>" is served here and never reaches the wrapped
// handler. Hooks.Push (receive-pack) or Hooks.Fetch (upload commands) is called
// only when the command completes without error; failures are reported to the
// client via Fatal. A session with no command is sent a table of the
// repositories in repoDir before the wrapped handler runs. Any other command
// goes straight to the wrapped handler.
func Middleware(repoDir string, cloneURL string, port int, gh Hooks) wish.Middleware {
	return func(sh ssh.Handler) ssh.Handler {
		return func(s ssh.Session) {
			sess := Session{s: s}
			cmd := s.Command()

			// Git operations
			if len(cmd) == 2 {
				gc := cmd[0]
				// repo should be in the form of "repo.git" or "user/repo.git"
				repo := strings.TrimSuffix(strings.TrimPrefix(cmd[1], "/"), "/")
				repo = filepath.Clean(repo)
				if n := strings.Count(repo, "/"); n > 1 {
					Fatal(s, ErrInvalidRepo)
					return
				}
				// filepath.Clean has already collapsed "a/../b" segments, so a name
				// that is still "." / ".." or starts with "../" would resolve to
				// repoDir itself or to a location outside of it.
				outsideRepoDir := repo == "." || repo == ".." ||
					strings.HasPrefix(repo, ".."+string(filepath.Separator))

				pk := s.PublicKey()
				switch gc {
				case "git-receive-pack", "git-upload-archive", "git-upload-pack":
					if outsideRepoDir {
						Fatal(s, ErrInvalidRepo)
						return
					}
					if err := gitPack(sess, gc, repoDir, repo); err != nil {
						// Fatal exits the session, so every failure path must return
						// here: writing again or firing a hook afterwards is wrong.
						if errors.Is(err, ErrInvalidRepo) {
							Fatal(s, ErrInvalidRepo)
							return
						}
						log.Error("unknown git error", "error", err)
						Fatal(s, ErrSystemMalfunction)
						return
					}
					if gc == "git-receive-pack" {
						gh.Push(repo, pk)
					} else {
						gh.Fetch(repo, pk)
					}
					return
				}
			}

			// Repo list
			if len(cmd) == 0 {
				des, err := os.ReadDir(repoDir)
				// os.ReadDir wraps failures in *fs.PathError, so == would never
				// match; a missing repoDir simply yields an empty listing.
				if err != nil && !errors.Is(err, fs.ErrNotExist) {
					log.Error("invalid repository", "error", err)
				}
				tw := tabwriter.NewWriter(s, 0, 0, 1, ' ', 0)
				for _, de := range des {
					if filepath.Ext(de.Name()) != ".git" {
						continue
					}
					repo, err := git.NewRepo(repoDir, de.Name())
					// Visibility is unknown when the repo could not be loaded.
					visibility := "❓"
					if err == nil {
						visibility = "🔓"
						if repo.Meta.Private {
							visibility = "🔒"
						}
					}
					fmt.Fprintf(tw, "%[1]s\t%[3]s\t%[2]s/%[1]s.git\n", strings.TrimSuffix(de.Name(), ".git"), cloneURL, visibility)
				}
				if err := tw.Flush(); err != nil {
					log.Error("could not write repository list", "error", err)
				}
			}
			sh(s)
		}
	}
}
// gitPack serves one git pack command for repoName (relative to repoDir) over
// the SSH session s.
//
// For "git-upload-pack" the repository path must already exist, otherwise
// ErrInvalidRepo is returned. For "git-receive-pack" git.EnsureRepo is called
// first (presumably creating the repo when missing — confirm in package git).
// The pack is then received and the dumb-HTTP server info is refreshed. Any
// other command, including "git-upload-archive", yields an
// "unknown git command" error.
//
// Names that resolve to repoDir itself or to a path outside it are rejected
// with ErrInvalidRepo. Other errors are wrapped with %w, so
// errors.Is(err, ErrInvalidRepo) remains reliable for callers.
func gitPack(s Session, gitCmd string, repoDir string, repoName string) error {
	rp := filepath.Join(repoDir, repoName)

	// Defense in depth: never read from, or create, anything outside repoDir.
	rel, err := filepath.Rel(repoDir, rp)
	if err != nil || rel == "." || rel == ".." ||
		strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return ErrInvalidRepo
	}

	protocol, err := git.NewProtocol(rp)
	if err != nil {
		return fmt.Errorf("creating git protocol for %q: %w", repoName, err)
	}

	switch gitCmd {
	case "git-upload-pack":
		// Check the error first: a failed existence check must not be
		// misreported to the client as a missing repository.
		exists, err := git.PathExists(rp)
		if err != nil {
			return fmt.Errorf("checking repository %q: %w", repoName, err)
		}
		if !exists {
			return ErrInvalidRepo
		}
		if err := protocol.SSHUploadPack(s); err != nil {
			return fmt.Errorf("upload-pack for %q: %w", repoName, err)
		}
		return nil
	case "git-receive-pack":
		if err := git.EnsureRepo(repoDir, repoName); err != nil {
			return fmt.Errorf("ensuring repository %q: %w", repoName, err)
		}
		repo, err := git.NewRepo(repoDir, repoName)
		if err != nil {
			return fmt.Errorf("opening repository %q: %w", repoName, err)
		}
		if err := protocol.SSHReceivePack(s, repo); err != nil {
			return fmt.Errorf("receive-pack for %q: %w", repoName, err)
		}
		// NOTE(review): the branch name is discarded; presumably DefaultBranch is
		// called for a side effect or as a post-push validity check — confirm.
		if _, err := repo.DefaultBranch(); err != nil {
			return fmt.Errorf("resolving default branch of %q: %w", repoName, err)
		}
		// Needed for git dumb http server
		if err := git.UpdateServerInfo(rp); err != nil {
			return fmt.Errorf("updating server info for %q: %w", repoName, err)
		}
		return nil
	default:
		return fmt.Errorf("unknown git command: %s", gitCmd)
	}
}
// Fatal writes the message formed from v (as by fmt.Sprint) to the session as
// a single git pkt-line, then exits the session with status 1.
func Fatal(s ssh.Session, v ...interface{}) {
	payload := fmt.Sprint(v...) + "\n"
	// A pkt-line's 4-hex-digit length counts the prefix itself plus the payload
	// (message and trailing newline).
	header := fmt.Sprintf("%04x", len(payload)+4)
	// Best-effort write: the session is exited with status 1 regardless.
	_, _ = wish.WriteString(s, header+payload)
	s.Exit(1) // nolint: errcheck
}