From 67923707ffd826ad1d02c0a5b5ebd75ffbc71364 Mon Sep 17 00:00:00 2001 From: Reto Brunner Date: Wed, 11 Nov 2020 23:50:35 +0100 Subject: Refactor send command --- commands/compose/send.go | 505 ++++++++++++++++++++++++++++------------------- 1 file changed, 304 insertions(+), 201 deletions(-) diff --git a/commands/compose/send.go b/commands/compose/send.go index e7ef509..40d0ae3 100644 --- a/commands/compose/send.go +++ b/commands/compose/send.go @@ -1,6 +1,7 @@ package compose import ( + "bytes" "crypto/tls" "fmt" "io" @@ -42,6 +43,7 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error { return errors.New("Usage: send") } composer, _ := aerc.SelectedTab().(*widgets.Composer) + tabName := aerc.TabNames()[aerc.SelectedTabIndex()] config := composer.Config() if config.Outgoing == "" { @@ -49,28 +51,6 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error { "No outgoing mail transport configured for this account") } - aerc.Logger().Println("Sending mail") - - uri, err := url.Parse(config.Outgoing) - if err != nil { - return errors.Wrap(err, "url.Parse(outgoing)") - } - var ( - scheme string - auth string = "plain" - ) - if uri.Scheme != "" { - parts := strings.Split(uri.Scheme, "+") - if len(parts) == 1 { - scheme = parts[0] - } else if len(parts) == 2 { - scheme = parts[0] - auth = parts[1] - } else { - return fmt.Errorf("Unknown transfer protocol %s", uri.Scheme) - } - } - header, err := composer.PrepareHeader() if err != nil { return errors.Wrap(err, "PrepareHeader") @@ -83,15 +63,187 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error { if config.From == "" { return errors.New("No 'From' configured for this account") } + // TODO: the user could conceivably want to use a different From and sender from, err := mail.ParseAddress(config.From) if err != nil { return errors.Wrap(err, "ParseAddress(config.From)") } - var ( - saslClient sasl.Client - conn *smtp.Client - ) + uri, err := url.Parse(config.Outgoing) + if err != nil { + return errors.Wrap(err, "url.Parse(outgoing)") + } + + scheme, auth, err := parseScheme(uri) + if err != nil { + return err + } + var starttls bool + if starttls_, ok := config.Params["smtp-starttls"]; ok { + starttls = starttls_ == "yes" + } + ctx := sendCtx{ + uri: uri, + scheme: scheme, + auth: auth, + starttls: starttls, + from: from, + rcpts: rcpts, + } + + var sender io.WriteCloser + switch ctx.scheme { + case "smtp": + fallthrough + case "smtps": + sender, err = newSmtpSender(ctx) + case "": + sender, err = newSendmailSender(ctx) + default: + sender, err = nil, fmt.Errorf("unsupported scheme %v", ctx.scheme) + } + if err != nil { + return errors.Wrap(err, "send:") + } + + // if we copy via the worker we need to know the count + counter := datacounter.NewWriterCounter(sender) + var writer io.Writer = counter + writer = counter + + var copyBuf bytes.Buffer + if config.CopyTo != "" { + writer = io.MultiWriter(writer, ©Buf) + } + + aerc.RemoveTab(composer) + aerc.PushStatus("Sending...", 10*time.Second) + + ch := make(chan error) + go func() { + err := composer.WriteMessage(header, writer) + if err != nil { + ch <- err + return + } + ch <- sender.Close() + }() + + // we don't want to block the UI thread while we are sending + go func() { + err = <-ch + if err != nil { + aerc.PushError(err.Error()) + aerc.NewTab(composer, tabName) + return + } + if config.CopyTo != "" { + aerc.PushStatus("Copying to "+config.CopyTo, 10*time.Second) + errCh := copyToSent(composer.Worker(), config.CopyTo, + int(counter.Count()), ©Buf) + err = <-errCh + if err != nil { + errmsg := fmt.Sprintf( + "message sent, but copying to %v failed: %v", + config.CopyTo, err.Error()) + aerc.PushError(errmsg) + composer.SetSent() + composer.Close() + return + } + } + aerc.PushStatus("Message sent.", 10*time.Second) + composer.SetSent() + composer.Close() + }() + return nil +} + +func listRecipients(h *mail.Header) ([]*mail.Address, error) { + var rcpts []*mail.Address + for _, key := range []string{"to", "cc", "bcc"} { + list, err := h.AddressList(key) + if err != nil { + return nil, err + } + rcpts = append(rcpts, list...) + } + return rcpts, nil +} + +type sendCtx struct { + uri *url.URL + scheme string + auth string + starttls bool + from *mail.Address + rcpts []*mail.Address +} + +func newSendmailSender(ctx sendCtx) (io.WriteCloser, error) { + args, err := shlex.Split(ctx.uri.Path) + if err != nil { + return nil, err + } + if len(args) == 0 { + return nil, fmt.Errorf("no command specified") + } + bin := args[0] + rs := make([]string, len(ctx.rcpts), len(ctx.rcpts)) + for i := range ctx.rcpts { + rs[i] = ctx.rcpts[i].Address + } + args = append(args[1:], rs...) + cmd := exec.Command(bin, args...) + s := &sendmailSender{cmd: cmd} + s.stdin, err = s.cmd.StdinPipe() + if err != nil { + return nil, errors.Wrap(err, "cmd.StdinPipe") + } + err = s.cmd.Start() + if err != nil { + return nil, errors.Wrap(err, "cmd.Start") + } + return s, nil +} + +type sendmailSender struct { + cmd *exec.Cmd + stdin io.WriteCloser +} + +func (s *sendmailSender) Write(p []byte) (int, error) { + return s.stdin.Write(p) +} + +func (s *sendmailSender) Close() error { + se := s.stdin.Close() + ce := s.cmd.Wait() + if se != nil { + return se + } + return ce +} + +func parseScheme(uri *url.URL) (scheme string, auth string, err error) { + scheme = "" + auth = "plain" + if uri.Scheme != "" { + parts := strings.Split(uri.Scheme, "+") + if len(parts) == 1 { + scheme = parts[0] + } else if len(parts) == 2 { + scheme = parts[0] + auth = parts[1] + } else { + return "", "", fmt.Errorf("Unknown transfer protocol %s", uri.Scheme) + } + } + return scheme, auth, nil +} + +func newSaslClient(auth string, uri *url.URL) (sasl.Client, error) { + var saslClient sasl.Client switch auth { case "": fallthrough @@ -105,7 +257,6 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error { saslClient = sasl.NewPlainClient("", uri.User.Username(), password) case "oauthbearer": q := uri.Query() - oauth2 := &oauth2.Config{} if q.Get("token_endpoint") != "" { oauth2.ClientID = q.Get("client_id") @@ -113,212 +264,164 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error { oauth2.Scopes = []string{q.Get("scope")} oauth2.Endpoint.TokenURL = q.Get("token_endpoint") } - password, _ := uri.User.Password() bearer := lib.OAuthBearer{ OAuth2: oauth2, Enabled: true, } if bearer.OAuth2.Endpoint.TokenURL == "" { - return fmt.Errorf("No 'TokenURL' configured for this account") + return nil, fmt.Errorf("No 'TokenURL' configured for this account") } token, err := bearer.ExchangeRefreshToken(password) if err != nil { - return err + return nil, err } password = token.AccessToken - saslClient = sasl.NewOAuthBearerClient(&sasl.OAuthBearerOptions{ Username: uri.User.Username(), Token: password, }) default: - return fmt.Errorf("Unsupported auth mechanism %s", auth) + return nil, fmt.Errorf("Unsupported auth mechanism %s", auth) } + return saslClient, nil +} - aerc.RemoveTab(composer) +type smtpSender struct { + ctx sendCtx + conn *smtp.Client + w io.WriteCloser +} - var starttls bool - if starttls_, ok := config.Params["smtp-starttls"]; ok { - starttls = starttls_ == "yes" +func (s *smtpSender) Write(p []byte) (int, error) { + return s.w.Write(p) +} + +func (s *smtpSender) Close() error { + we := s.w.Close() + ce := s.conn.Close() + if we != nil { + return we } + return ce +} - smtpAsync := func() (int, error) { - switch scheme { - case "smtp": - host := uri.Host - serverName := uri.Host - if !strings.ContainsRune(host, ':') { - host = host + ":587" // Default to submission port - } else { - serverName = host[:strings.IndexRune(host, ':')] - } - conn, err = smtp.Dial(host) - if err != nil { - return 0, errors.Wrap(err, "smtp.Dial") - } - defer conn.Close() - if sup, _ := conn.Extension("STARTTLS"); sup { - if !starttls { - err := errors.New("STARTTLS is supported by this server, " + - "but not set in accounts.conf. " + - "Add smtp-starttls=yes") - return 0, err - } - if err = conn.StartTLS(&tls.Config{ - ServerName: serverName, - }); err != nil { - return 0, errors.Wrap(err, "StartTLS") - } - } else { - if starttls { - err := errors.New("STARTTLS requested, but not supported " + - "by this SMTP server. Is someone tampering with your " + - "connection?") - return 0, err - } - } - case "smtps": - host := uri.Host - serverName := uri.Host - if !strings.ContainsRune(host, ':') { - host = host + ":465" // Default to smtps port - } else { - serverName = host[:strings.IndexRune(host, ':')] - } - conn, err = smtp.DialTLS(host, &tls.Config{ - ServerName: serverName, - }) - if err != nil { - return 0, errors.Wrap(err, "smtp.DialTLS") - } - defer conn.Close() - } +func newSmtpSender(ctx sendCtx) (io.WriteCloser, error) { + var ( + err error + conn *smtp.Client + ) + switch ctx.scheme { + case "smtp": + conn, err = connectSmtp(ctx.starttls, ctx.uri.Host) + case "smtps": + conn, err = connectSmtps(ctx.uri.Host) + default: + return nil, fmt.Errorf("not an smtp protocol %s", ctx.scheme) + } - if saslClient != nil { - if err = conn.Auth(saslClient); err != nil { - return 0, errors.Wrap(err, "conn.Auth") - } - } - // TODO: the user could conceivably want to use a different From and sender - if err = conn.Mail(from.Address, nil); err != nil { - return 0, errors.Wrap(err, "conn.Mail") - } - aerc.Logger().Printf("rcpt to: %v", rcpts) - for _, rcpt := range rcpts { - if err = conn.Rcpt(rcpt); err != nil { - return 0, errors.Wrap(err, "conn.Rcpt") - } + saslclient, err := newSaslClient(ctx.auth, ctx.uri) + if err != nil { + conn.Close() + return nil, err + } + if saslclient != nil { + if err := conn.Auth(saslclient); err != nil { + conn.Close() + return nil, errors.Wrap(err, "conn.Auth") } - wc, err := conn.Data() - if err != nil { - return 0, errors.Wrap(err, "conn.Data") + } + s := &smtpSender{ + ctx: ctx, + conn: conn, + } + if err := s.conn.Mail(s.ctx.from.Address, nil); err != nil { + conn.Close() + return nil, errors.Wrap(err, "conn.Mail") + } + for _, rcpt := range s.ctx.rcpts { + if err := s.conn.Rcpt(rcpt.Address); err != nil { + conn.Close() + return nil, errors.Wrap(err, "conn.Rcpt") } - defer wc.Close() - ctr := datacounter.NewWriterCounter(wc) - composer.WriteMessage(header, ctr) - return int(ctr.Count()), nil } + s.w, err = s.conn.Data() + if err != nil { + conn.Close() + return nil, errors.Wrap(err, "conn.Data") + } + return s.w, nil +} - sendmailAsync := func() (int, error) { - args, err := shlex.Split(uri.Path) - if err != nil { - return 0, err - } - if len(args) == 0 { - return 0, fmt.Errorf("no command specified") - } - bin := args[0] - args = append(args[1:], rcpts...) - cmd := exec.Command(bin, args...) - wc, err := cmd.StdinPipe() - if err != nil { - return 0, errors.Wrap(err, "cmd.StdinPipe") - } - err = cmd.Start() - if err != nil { - return 0, errors.Wrap(err, "cmd.Start") +func connectSmtp(starttls bool, host string) (*smtp.Client, error) { + serverName := host + if !strings.ContainsRune(host, ':') { + host = host + ":587" // Default to submission port + } else { + serverName = host[:strings.IndexRune(host, ':')] + } + conn, err := smtp.Dial(host) + if err != nil { + return nil, errors.Wrap(err, "smtp.Dial") + } + if sup, _ := conn.Extension("STARTTLS"); sup { + if !starttls { + err := errors.New("STARTTLS is supported by this server, " + + "but not set in accounts.conf. " + + "Add smtp-starttls=yes") + conn.Close() + return nil, err } - ctr := datacounter.NewWriterCounter(wc) - composer.WriteMessage(header, ctr) - wc.Close() // force close to make sendmail send - err = cmd.Wait() - if err != nil { - return 0, errors.Wrap(err, "cmd.Wait") + if err = conn.StartTLS(&tls.Config{ + ServerName: serverName, + }); err != nil { + conn.Close() + return nil, errors.Wrap(err, "StartTLS") } - return int(ctr.Count()), nil - } - - sendAsync := func() (int, error) { - fmt.Println(scheme) - switch scheme { - case "smtp": - fallthrough - case "smtps": - return smtpAsync() - case "": - return sendmailAsync() + } else { + if starttls { + err := errors.New("STARTTLS requested, but not supported " + + "by this SMTP server. Is someone tampering with your " + + "connection?") + conn.Close() + return nil, err } - return 0, errors.New("Unknown scheme") } + return conn, nil +} - go func() { - aerc.PushStatus("Sending...", 10*time.Second) - nbytes, err := sendAsync() - if err != nil { - aerc.PushError(" " + err.Error()) - return - } - if config.CopyTo != "" { - aerc.PushStatus("Copying to "+config.CopyTo, 10*time.Second) - worker := composer.Worker() - r, w := io.Pipe() - worker.PostAction(&types.AppendMessage{ - Destination: config.CopyTo, - Flags: []models.Flag{models.SeenFlag}, - Date: time.Now(), - Reader: r, - Length: nbytes, - }, func(msg types.WorkerMessage) { - switch msg := msg.(type) { - case *types.Done: - aerc.PushStatus("Message sent.", 10*time.Second) - r.Close() - composer.SetSent() - composer.Close() - case *types.Error: - aerc.PushError(" " + msg.Error.Error()) - r.Close() - composer.Close() - } - }) - header, err := composer.PrepareHeader() - if err != nil { - aerc.PushError(" " + err.Error()) - w.Close() - return - } - composer.WriteMessage(header, w) - w.Close() - } else { - aerc.PushStatus("Message sent.", 10*time.Second) - composer.SetSent() - composer.Close() - } - }() - return nil +func connectSmtps(host string) (*smtp.Client, error) { + serverName := host + if !strings.ContainsRune(host, ':') { + host = host + ":465" // Default to smtps port + } else { + serverName = host[:strings.IndexRune(host, ':')] + } + conn, err := smtp.DialTLS(host, &tls.Config{ + ServerName: serverName, + }) + if err != nil { + return nil, errors.Wrap(err, "smtp.DialTLS") + } + return conn, nil } -func listRecipients(h *mail.Header) ([]string, error) { - var rcpts []string - for _, key := range []string{"to", "cc", "bcc"} { - list, err := h.AddressList(key) - if err != nil { - return nil, err - } - for _, addr := range list { - rcpts = append(rcpts, addr.Address) +func copyToSent(worker *types.Worker, dest string, + n int, msg io.Reader) <-chan error { + errCh := make(chan error) + worker.PostAction(&types.AppendMessage{ + Destination: dest, + Flags: []models.Flag{models.SeenFlag}, + Date: time.Now(), + Reader: msg, + Length: n, + }, func(msg types.WorkerMessage) { + switch msg := msg.(type) { + case *types.Done: + errCh <- nil + case *types.Error: + errCh <- msg.Error } - } - return rcpts, nil + }) + return errCh } -- cgit v1.2.3