122 lines
3.2 KiB
Go
122 lines
3.2 KiB
Go
package ssh
|
|
|
|
import (
	"fmt"
	"log"
	"net"
	"os"
	"os/signal"
	"syscall"

	"golang.org/x/crypto/ssh"
)
|
|
|
|
// SSHConfig holds the endpoint and credential settings consumed by ConnectSSH.
type SSHConfig struct {
	// Remote account name, server host, and server port. The port is kept
	// as a string (for example "22").
	User, Host, Port string

	// Credentials. Each non-empty one is offered to the server. PrivateKey
	// is the filesystem path of a key file, not the key material itself.
	Password, PrivateKey string
}
|
|
|
|
// ConnectSSH establishes an SSH connection and starts an interactive session.
|
|
func ConnectSSH(config SSHConfig) error {
|
|
log.Printf("Attempting SSH connection to %s@%s:%s", config.User, config.Host, config.Port)
|
|
|
|
// Build SSH authentication methods
|
|
var authMethods []ssh.AuthMethod
|
|
|
|
// Add key-based authentication if a private key is provided
|
|
if config.PrivateKey != "" {
|
|
log.Println("Using private key authentication...")
|
|
key, err := os.ReadFile(config.PrivateKey)
|
|
if err != nil {
|
|
log.Printf("Error reading private key: %v", err)
|
|
return fmt.Errorf("failed to read private key: %w", err)
|
|
}
|
|
|
|
signer, err := ssh.ParsePrivateKey(key)
|
|
if err != nil {
|
|
log.Printf("Error parsing private key: %v", err)
|
|
return fmt.Errorf("failed to parse private key: %w", err)
|
|
}
|
|
|
|
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
|
}
|
|
|
|
// Add password-based authentication if a password is provided
|
|
if config.Password != "" {
|
|
log.Println("Using password authentication...")
|
|
authMethods = append(authMethods, ssh.Password(config.Password))
|
|
}
|
|
|
|
// Ensure at least one authentication method is configured
|
|
if len(authMethods) == 0 {
|
|
return fmt.Errorf("no authentication method provided (password or private key)")
|
|
}
|
|
|
|
// Configure SSH client
|
|
clientConfig := &ssh.ClientConfig{
|
|
User: config.User,
|
|
Auth: authMethods,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // For production, use a proper host key callback!
|
|
}
|
|
|
|
address := fmt.Sprintf("%s:%s", config.Host, config.Port)
|
|
log.Printf("Connecting to SSH server at %s...", address)
|
|
|
|
// Connect to the SSH server
|
|
client, err := ssh.Dial("tcp", address, clientConfig)
|
|
if err != nil {
|
|
log.Printf("Failed to connect to SSH server: %v", err)
|
|
return fmt.Errorf("failed to connect to SSH server: %w", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
log.Println("SSH connection established. Starting interactive session...")
|
|
return startInteractiveSession(client)
|
|
}
|
|
|
|
// startInteractiveSession starts an interactive shell session.
|
|
func startInteractiveSession(client *ssh.Client) error {
|
|
// Create a new session
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create SSH session: %w", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
// Set up terminal modes
|
|
modes := ssh.TerminalModes{
|
|
ssh.ECHO: 1, // Enable echoing
|
|
ssh.TTY_OP_ISPEED: 14400,
|
|
ssh.TTY_OP_OSPEED: 14400,
|
|
}
|
|
|
|
// Request a pseudo-terminal
|
|
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
|
|
return fmt.Errorf("failed to request pseudo-terminal: %w", err)
|
|
}
|
|
|
|
// Set up input/output for the session
|
|
session.Stdin = os.Stdin
|
|
session.Stdout = os.Stdout
|
|
session.Stderr = os.Stderr
|
|
|
|
// Start an interactive shell
|
|
if err := session.Shell(); err != nil {
|
|
return fmt.Errorf("failed to start shell: %w", err)
|
|
}
|
|
|
|
// Handle Ctrl+C and other interrupts to cleanly close the session
|
|
signalChan := make(chan os.Signal, 1)
|
|
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
|
|
go func() {
|
|
<-signalChan
|
|
session.Close()
|
|
os.Exit(0)
|
|
}()
|
|
|
|
// Wait for the session to end
|
|
return session.Wait()
|
|
}
|