diff --git a/ssh/connect.go b/ssh/connect.go index 8c0ffb5..8d3efed 100644 --- a/ssh/connect.go +++ b/ssh/connect.go @@ -2,8 +2,10 @@ package ssh import ( "fmt" - "io/ioutil" "log" + "os" + "os/signal" + "syscall" "golang.org/x/crypto/ssh" ) @@ -16,7 +18,7 @@ type SSHConfig struct { PrivateKey string } -// ConnectSSH handles both key-based and password authentication dynamically. +// ConnectSSH establishes an SSH connection and starts an interactive session. func ConnectSSH(config SSHConfig) error { log.Printf("Attempting SSH connection to %s@%s:%s", config.User, config.Host, config.Port) @@ -25,8 +27,8 @@ func ConnectSSH(config SSHConfig) error { // Add key-based authentication if a private key is provided if config.PrivateKey != "" { - log.Println("Attempting key-based authentication...") - key, err := ioutil.ReadFile(config.PrivateKey) + log.Println("Using private key authentication...") + key, err := os.ReadFile(config.PrivateKey) if err != nil { log.Printf("Error reading private key: %v", err) return fmt.Errorf("failed to read private key: %w", err) @@ -43,7 +45,7 @@ func ConnectSSH(config SSHConfig) error { // Add password-based authentication if a password is provided if config.Password != "" { - log.Println("Attempting password-based authentication...") + log.Println("Using password authentication...") authMethods = append(authMethods, ssh.Password(config.Password)) } @@ -56,13 +58,13 @@ func ConnectSSH(config SSHConfig) error { clientConfig := &ssh.ClientConfig{ User: config.User, Auth: authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Use a proper host key callback for production! + HostKeyCallback: ssh.InsecureIgnoreHostKey(), // For production, use a proper host key callback! } - // Dial the SSH server address := fmt.Sprintf("%s:%s", config.Host, config.Port) log.Printf("Connecting to SSH server at %s...", address) + // Connect to the SSH server client, err := ssh.Dial("tcp", address, clientConfig) if err != nil { log.Printf("Failed to connect to SSH server: %v", err) @@ -70,6 +72,50 @@ func ConnectSSH(config SSHConfig) error { } defer client.Close() - log.Printf("Successfully connected to %s@%s:%s", config.User, config.Host, config.Port) - return nil + log.Println("SSH connection established. Starting interactive session...") + return startInteractiveSession(client) +} + +// startInteractiveSession starts an interactive shell session. +func startInteractiveSession(client *ssh.Client) error { + // Create a new session + session, err := client.NewSession() + if err != nil { + return fmt.Errorf("failed to create SSH session: %w", err) + } + defer session.Close() + + // Set up terminal modes + modes := ssh.TerminalModes{ + ssh.ECHO: 1, // Enable echoing + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + } + + // Request a pseudo-terminal + if err := session.RequestPty("xterm", 80, 40, modes); err != nil { + return fmt.Errorf("failed to request pseudo-terminal: %w", err) + } + + // Set up input/output for the session + session.Stdin = os.Stdin + session.Stdout = os.Stdout + session.Stderr = os.Stderr + + // Start an interactive shell + if err := session.Shell(); err != nil { + return fmt.Errorf("failed to start shell: %w", err) + } + + // Handle Ctrl+C and other interrupts to cleanly close the session + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM) + go func() { + <-signalChan + session.Close() + os.Exit(0) + }() + + // Wait for the session to end + return session.Wait() } diff --git a/tui/interface.go b/tui/interface.go index 08e1978..d74e88c 100644 --- a/tui/interface.go +++ b/tui/interface.go @@ -1,27 +1,83 @@ package tui import ( + "log" + "tui-ssh-app/ssh" + "github.com/rivo/tview" ) func StartTUI() { app := tview.NewApplication() + // Input fields for SSH configuration + userField := tview.NewInputField().SetLabel("User: ") + hostField := tview.NewInputField().SetLabel("Host: ") + portField := tview.NewInputField().SetLabel("Port: ").SetText("22") + passwordField := tview.NewInputField().SetLabel("Password: ") + privateKeyField := tview.NewInputField().SetLabel("Private Key Path: ") + + // Output box for connection status + outputBox := tview.NewTextView().SetDynamicColors(true) + + // Layout form := tview.NewForm(). - AddInputField("User", "", 20, nil, nil). - AddInputField("Host", "", 20, nil, nil). - AddInputField("Port", "22", 5, nil, nil). - AddPasswordField("Password", "", 20, '*', nil). - AddInputField("Private Key", "", 20, nil, nil). + AddFormItem(userField). + AddFormItem(hostField). + AddFormItem(portField). + AddFormItem(passwordField). + AddFormItem(privateKeyField). AddButton("Connect", func() { - // Handle SSH connection logic + // Read inputs + user := userField.GetText() + host := hostField.GetText() + port := portField.GetText() + password := passwordField.GetText() + privateKey := privateKeyField.GetText() + + // Validate inputs + if user == "" || host == "" || port == "" { + outputBox.SetText("[red]Error: All fields except Password/Private Key are required!").ScrollToEnd() + return + } + + // Display connection status + outputBox.SetText("[yellow]Connecting...").ScrollToEnd() + + // Perform SSH connection in a goroutine to avoid blocking the TUI + go func() { + err := ssh.ConnectSSH(ssh.SSHConfig{ + User: user, + Host: host, + Port: port, + Password: password, + PrivateKey: privateKey, + }) + + if err != nil { + log.Printf("SSH connection failed: %v", err) + app.QueueUpdateDraw(func() { + outputBox.SetText("[red]Connection failed: " + err.Error()).ScrollToEnd() + }) + return + } + + app.QueueUpdateDraw(func() { + outputBox.SetText("[green]Connection successful!").ScrollToEnd() + }) + }() }). AddButton("Quit", func() { app.Stop() }) - form.SetBorder(true).SetTitle("SSH Connection").SetTitleAlign(tview.AlignLeft) - if err := app.SetRoot(form, true).Run(); err != nil { - panic(err) + layout := tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(form, 0, 1, true). + AddItem(outputBox, 0, 1, false) + + // Run the application + if err := app.SetRoot(layout, true).Run(); err != nil { + log.Fatalf("Error running application: %v", err) } }