diff --git a/src/bin/psql/startup.c b/src/bin/psql/startup.c index 5d7fe6e..30af5b5 100644 --- a/src/bin/psql/startup.c +++ b/src/bin/psql/startup.c @@ -89,7 +89,9 @@ main(int argc, char *argv[]) int successResult; char *password = NULL; char *password_prompt = NULL; - bool new_pass; +#define PARAMS_ARRAY_SIZE 8 + const char **keywords = pg_malloc(PARAMS_ARRAY_SIZE * sizeof(*keywords)); + const char **values = pg_malloc(PARAMS_ARRAY_SIZE * sizeof(*values)); set_pglocale_pgservice(argv[0], PG_TEXTDOMAIN("psql")); @@ -197,50 +199,38 @@ main(int argc, char *argv[]) if (pset.getPassword == TRI_YES) password = simple_prompt(password_prompt, 100, false); - /* loop until we have a password if requested by backend */ - do - { -#define PARAMS_ARRAY_SIZE 8 - const char **keywords = pg_malloc(PARAMS_ARRAY_SIZE * sizeof(*keywords)); - const char **values = pg_malloc(PARAMS_ARRAY_SIZE * sizeof(*values)); - - keywords[0] = "host"; - values[0] = options.host; - keywords[1] = "port"; - values[1] = options.port; - keywords[2] = "user"; - values[2] = options.username; - keywords[3] = "password"; - values[3] = password; - keywords[4] = "dbname"; - values[4] = (options.action == ACT_LIST_DB && - options.dbname == NULL) ? - "postgres" : options.dbname; - keywords[5] = "fallback_application_name"; - values[5] = pset.progname; - keywords[6] = "client_encoding"; - values[6] = (pset.notty || getenv("PGCLIENTENCODING")) ? NULL : "auto"; - keywords[7] = NULL; - values[7] = NULL; - - new_pass = false; - pset.db = PQconnectdbParams(keywords, values, true); - free(keywords); - free(values); - - if (PQstatus(pset.db) == CONNECTION_BAD && - PQconnectionNeedsPassword(pset.db) && - password == NULL && - pset.getPassword != TRI_NO) - { - PQfinish(pset.db); - password = simple_prompt(password_prompt, 100, false); - new_pass = true; - } - } while (new_pass); - free(password); - free(password_prompt); + keywords[0] = "host"; + values[0] = options.host; + keywords[1] = "port"; + values[1] = options.port; + keywords[2] = "user"; + values[2] = options.username; + keywords[3] = "password"; + values[3] = password; + keywords[4] = "dbname"; + values[4] = (options.action == ACT_LIST_DB && + options.dbname == NULL) ? + "postgres" : options.dbname; + keywords[5] = "fallback_application_name"; + values[5] = pset.progname; + keywords[6] = "client_encoding"; + values[6] = (pset.notty || getenv("PGCLIENTENCODING")) ? NULL : "auto"; + keywords[7] = NULL; + values[7] = NULL; + + pset.db = PQconnectdbParams(keywords, values, true); + free(keywords); + free(values); + + if (PQstatus(pset.db) == CONNECTION_BAD && + PQconnectionNeedsPassword(pset.db) && + password == NULL && + pset.getPassword != TRI_NO) + { + password = simple_prompt(password_prompt, 100, false); + PQsendPassword(pset.db, password); + } if (PQstatus(pset.db) == CONNECTION_BAD) { diff --git a/src/interfaces/libpq/exports.txt b/src/interfaces/libpq/exports.txt index 93da50d..797b5d3 100644 --- a/src/interfaces/libpq/exports.txt +++ b/src/interfaces/libpq/exports.txt @@ -165,3 +165,4 @@ lo_lseek64 162 lo_tell64 163 lo_truncate64 164 PQconninfo 165 +PQsendPassword 166 diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c index ae9dfaa..402384f 100644 --- a/src/interfaces/libpq/fe-connect.c +++ b/src/interfaces/libpq/fe-connect.c @@ -466,6 +466,21 @@ PQconnectdbParams(const char *const * keywords, } /* + * PQsendPassword + * + * send a password that the server asked for halfway between a connection sequence. + */ +void +PQsendPassword(PGconn *conn, char *password) +{ + conn->pgpass = password; + conn->status = CONNECTION_SENDING_PASSWORD; + + resetPQExpBuffer(&conn->errorMessage); + (void) connectDBComplete(conn); +} + +/* * PQpingParams * * check server status, accepting parameters identical to PQconnectdbParams @@ -1555,6 +1570,7 @@ PQconnectPoll(PGconn *conn) PGresult *res; char sebuf[256]; int optval; + static AuthRequest areq; if (conn == NULL) return PGRES_POLLING_FAILED; @@ -1598,6 +1614,7 @@ PQconnectPoll(PGconn *conn) /* Special cases: proceed without waiting. */ case CONNECTION_SSL_STARTUP: case CONNECTION_NEEDED: + case CONNECTION_SENDING_PASSWORD: break; default: @@ -2160,7 +2177,6 @@ keep_going: /* We will come back to here until there is char beresp; int msgLength; int avail; - AuthRequest areq; /* * Scan the message from current point (note that if we find @@ -2442,7 +2458,34 @@ keep_going: /* We will come back to here until there is /* Look to see if we have more data yet. */ goto keep_going; } + case CONNECTION_SENDING_PASSWORD: + { + /* + * Note that conn->pghost must be non-NULL if we are going to + * avoid the Kerberos code doing a hostname look-up. + */ + + if (pg_fe_sendauth(areq, conn) != STATUS_OK) + { + conn->errorMessage.len = strlen(conn->errorMessage.data); + goto error_return; + } + conn->errorMessage.len = strlen(conn->errorMessage.data); + + /* + * Just make sure that any data sent by pg_fe_sendauth is + * flushed out. Although this theoretically could block, it + * really shouldn't since we don't send large auth responses. + */ + if (pqFlush(conn)) + goto error_return; + /* Now go back to reading backend's response to the password just sent + * in the current authentication sequence + */ + conn->status = CONNECTION_AWAITING_RESPONSE; + return PGRES_POLLING_READING; + } case CONNECTION_AUTH_OK: { /* diff --git a/src/interfaces/libpq/libpq-fe.h b/src/interfaces/libpq/libpq-fe.h index e0f4bc7..5888a58 100644 --- a/src/interfaces/libpq/libpq-fe.h +++ b/src/interfaces/libpq/libpq-fe.h @@ -62,7 +62,11 @@ typedef enum * backend startup. */ CONNECTION_SETENV, /* Negotiating environment. */ CONNECTION_SSL_STARTUP, /* Negotiating SSL. */ - CONNECTION_NEEDED /* Internal state: connect() needed */ + CONNECTION_NEEDED, /* Internal state: connect() needed */ + CONNECTION_SENDING_PASSWORD /* An intermediate state to help client send a password + * over an existing connection + */ + } ConnStatusType; typedef enum @@ -258,6 +262,9 @@ extern PGconn *PQsetdbLogin(const char *pghost, const char *pgport, #define PQsetdb(M_PGHOST,M_PGPORT,M_PGOPT,M_PGTTY,M_DBNAME) \ PQsetdbLogin(M_PGHOST, M_PGPORT, M_PGOPT, M_PGTTY, M_DBNAME, NULL, NULL) +/* send a password that the server asked for halfway between a connection sequence */ +extern void PQsendPassword(PGconn *conn, char *password); + /* close the current connection and free the PGconn data structure */ extern void PQfinish(PGconn *conn);