/* 
   HTTP session handling
   Copyright (C) 1999-2002, Joe Orton <joe@manyfish.co.uk>
   Portions are:
   Copyright (C) 1999-2000 Tommi Komulainen <Tommi.Komulainen@iki.fi>

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public
   License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.
   
   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.

   You should have received a copy of the GNU Library General Public
   License along with this library; if not, write to the Free
   Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
   MA 02111-1307, USA

*/

#include "config.h"

#ifdef HAVE_STRING_H
#include <string.h>
#endif
#ifdef HAVE_STDLIB_H
#include <stdlib.h>
#endif
#ifdef HAVE_ERRNO_H
#include <errno.h>
#endif

#ifdef NEON_SSL
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/pkcs12.h>
#endif

#include "ne_session.h"
#include "ne_alloc.h"
#include "ne_utils.h"
#include "ne_i18n.h"

#include "ne_private.h"

/* Product token appended to the User-Agent header by ne_set_useragent().
 * Deliberately has no trailing semicolon, so the macro can be used
 * inside expressions and initialisers. */
#define NEON_USERAGENT "neon/" NEON_VERSION

#ifdef NEON_SSL
/* Forward declaration: ne_session_create() installs this as the SSL
 * client certificate callback before its definition below. */
static int provide_client_cert(SSL *ssl, X509 **cert, EVP_PKEY **pkey);
#endif

/* Free every node of the singly-linked hook list headed by 'hooks'.
 * Only the list nodes are freed; each hook's userdata is left alone. */
static void destroy_hooks(struct hook *hooks)
{
    while (hooks != NULL) {
	struct hook *victim = hooks;
	hooks = hooks->next;
	free(victim);
    }
}

/* Tear down session 'sess'.  In order, this:
 *  - runs each registered destroy-session hook with its userdata;
 *  - frees every hook list and the session's strings;
 *  - closes the connection if one is open;
 *  - releases the SSL context, cached SSL session and cached server
 *    certificate (NEON_SSL builds only);
 *  - frees the session itself.
 * 'sess' must not be used afterwards.
 *
 * NOTE(review): sess->client_cert and sess->client_key (set by
 * ne_ssl_load_pem() / ne_ssl_load_pkcs12()) are not released here. */
void ne_session_destroy(ne_session *sess) 
{
    struct hook *hk = sess->destroy_sess_hooks;

    NE_DEBUG(NE_DBG_HTTP, "ne_session_destroy called.\n");

    /* Let every destroy hook clean up before its node is freed below. */
    while (hk != NULL) {
	ne_destory_fn destroy = (ne_destory_fn)hk->fn;
	destroy(hk->userdata);
	hk = hk->next;
    }

    destroy_hooks(sess->create_req_hooks);
    destroy_hooks(sess->pre_send_hooks);
    destroy_hooks(sess->post_send_hooks);
    destroy_hooks(sess->destroy_req_hooks);
    destroy_hooks(sess->destroy_sess_hooks);
    destroy_hooks(sess->accessor_hooks);

    NE_FREE(sess->server.hostname);
    NE_FREE(sess->server.hostport);
    NE_FREE(sess->scheme);
    NE_FREE(sess->proxy.hostport);
    NE_FREE(sess->user_agent);

    if (sess->connected)
	ne_close_connection(sess);

#ifdef NEON_SSL
    if (sess->ssl_context != NULL)
	SSL_CTX_free(sess->ssl_context);
    if (sess->ssl_sess != NULL)
	SSL_SESSION_free(sess->ssl_sess);
    if (sess->server_cert != NULL)
	X509_free(sess->server_cert);
#endif

    free(sess);
}

/* Returns the value of VERSION_PRE11() for session 's' — presumably
 * non-zero when the server's HTTP version is older than 1.1. */
int ne_version_pre_http11(ne_session *s)
{
    const int pre11 = VERSION_PRE11(s);
    return pre11;
}

/* Build the "hostname[:port]" string for 'host' in newly allocated
 * memory.  The ":port" suffix is omitted when the port is 80.
 * FIXME: for SSL (the default port there is not 80). */
static char *get_hostport(struct host_info *host) 
{
    const size_t namelen = strlen(host->hostname);
    char *hostport = ne_malloc(namelen + 10);

    /* copy the name including its terminating NUL */
    memcpy(hostport, host->hostname, namelen + 1);
    if (host->port != 80)
	ne_snprintf(&hostport[namelen], 9, ":%d", host->port);
    return hostport;
}

/* Replace the hostname and port held in 'info' with copies of
 * 'hostname'/'port', discarding the previous strings.  The cached
 * "hostport" string is regenerated to match. */
static void
set_hostinfo(struct host_info *info, const char *hostname, int port)
{
    NE_FREE(info->hostname);
    NE_FREE(info->hostport);

    info->port = port;
    info->hostname = ne_strdup(hostname);
    /* must come last: derived from hostname and port */
    info->hostport = get_hostport(info);
}

/* Create a session for requests to 'hostname':'port' using URI scheme
 * 'scheme'.
 *
 * An "https" scheme marks the session as using SSL.  When built with
 * NEON_SSL, an SSL context is also created, with the client-certificate
 * callback installed.
 *
 * Expect: 100-continue starts disabled.  Returns the new session;
 * release it with ne_session_destroy(). */
ne_session *ne_session_create(const char *scheme,
			      const char *hostname, int port)
{
    ne_session *sess = ne_calloc(sizeof *sess);

    NE_DEBUG(NE_DBG_HTTP, "HTTP session to %s://%s:%d begins.\n",
	     scheme, hostname, port);

    strcpy(sess->error, "Unknown error.");
    sess->version_major = -1;
    sess->version_minor = -1;

    /* set the hostname/port */
    set_hostinfo(&sess->server, hostname, port);
    
    /* use SSL if scheme is https */
    sess->use_ssl = !strcmp(scheme, "https");

#ifdef NEON_SSL
    if (sess->use_ssl) {
	sess->ssl_context = SSL_CTX_new(SSLv23_client_method());
	if (sess->ssl_context != NULL) {
	    /* set client cert callback. */
	    SSL_CTX_set_client_cert_cb(sess->ssl_context, provide_client_cert);
	} else {
	    /* SSL_CTX_new() can fail (e.g. out of memory); never hand a
	     * NULL context to SSL_CTX_set_client_cert_cb().  Later SSL
	     * operations on this session will presumably fail instead. */
	    NE_DEBUG(NE_DBG_HTTP, "Could not create SSL context.\n");
	}
    }
#endif

    sess->scheme = ne_strdup(scheme);

    /* Default expect-100 to OFF. */
    sess->expect100_works = -1;
    return sess;
}

/* Route the session's requests through the proxy at 'hostname':'port'. */
void ne_session_proxy(ne_session *sess, const char *hostname, int port)
{
    struct host_info *proxy = &sess->proxy;

    sess->use_proxy = 1;
    set_hostinfo(proxy, hostname, port);
}

/* Format a printf-style message into the session's error string.  At
 * most BUFSIZ bytes (including the terminator) are written. */
void ne_set_error(ne_session *sess, const char *format, ...)
{
    va_list ap;

    va_start(ap, format);
    ne_vsnprintf(sess->error, BUFSIZ, format, ap);
    va_end(ap);
}


/* Register the socket progress callback 'progress' and the userdata
 * passed to it. */
void ne_set_progress(ne_session *sess, 
		       sock_progress progress, void *userdata)
{
    sess->progress_ud = userdata;
    sess->progress_cb = progress;
}

/* Register the connection-status notification callback 'status' and
 * the userdata passed to it. */
void ne_set_status(ne_session *sess,
		     ne_notify_status status, void *userdata)
{
    sess->notify_ud = userdata;
    sess->notify_cb = status;
}

/* Enable (non-zero) or disable (zero) use of Expect: 100-continue.
 * Enabled is stored as 1; disabled as -1, the value ne_session_create()
 * uses for "off". */
void ne_set_expect100(ne_session *sess, int use_expect100)
{
    sess->expect100_works = use_expect100 ? 1 : -1;
}

/* Allow (non-zero) or forbid (zero) persistent connections. */
void ne_set_persist(ne_session *sess, int persist)
{
    sess->no_persist = (persist == 0);
}

/* Set the socket read timeout for the session. */
void ne_set_read_timeout(ne_session *sess, int timeout)
{
    const int rd = timeout;
    sess->rdtimeout = rd;
}

/* Set the User-Agent string to the caller's product 'token' followed by
 * " " and NEON_USERAGENT.  Any previous value is freed. */
void ne_set_useragent(ne_session *sess, const char *token)
{
    static const char suffix[] = " " NEON_USERAGENT;

    NE_FREE(sess->user_agent);
    CONCAT2(sess->user_agent, token, suffix);
}

/* Returns the server's "hostname[:port]" string (owned by the session). */
const char *ne_get_server_hostport(ne_session *sess)
{
    const struct host_info *srv = &sess->server;
    return srv->hostport;
}

/* Returns the URI scheme the session was created with (owned by the
 * session). */
const char *ne_get_scheme(ne_session *sess)
{
    const char *scheme = sess->scheme;
    return scheme;
}

/* Returns the session's most recent error string (owned by the session;
 * overwritten by later errors). */
const char *ne_get_error(ne_session *sess)
{
    const char *msg = sess->error;
    return msg;
}

/* Close the session's socket if it is open, and mark the session as
 * disconnected in either case.  Always returns 0. */
int ne_close_connection(ne_session *sess)
{
    const int was_open = sess->connected > 0;

    NE_DEBUG(NE_DBG_SOCKET, "Closing connection.\n");
    if (was_open) {
	sock_close(sess->socket);
	sess->socket = NULL;
    }
    sess->connected = 0;
    NE_DEBUG(NE_DBG_SOCKET, "Connection closed.\n");
    return 0;
}

/* Register the application's server certificate verification callback
 * 'fn' and its userdata.  check_certificate() consults it when
 * verification finds failures. */
void ne_ssl_set_verify(ne_session *sess, ne_ssl_verify_fn fn, void *userdata)
{
    sess->ssl_verify_ud = userdata;
    sess->ssl_verify_fn = fn;
}

/* Register the callback 'fn' (with its userdata).  It is invoked when
 * the server requests a client certificate and none has been loaded. */
void ne_ssl_provide_ccert(ne_session *sess, 
			  ne_ssl_provide_fn fn, void *userdata)
{
    sess->ssl_provide_ud = userdata;
    sess->ssl_provide_fn = fn;
}

#ifdef NEON_SSL

/* Returns the session's SSL context after bumping its reference count
 * (a plain increment, without OpenSSL's locking).  Presumably the
 * caller releases its reference with SSL_CTX_free(). */
SSL_CTX *ne_ssl_get_context(ne_session *sess)
{
    SSL_CTX *ctx = sess->ssl_context;

    ctx->references++;
    return ctx;
}

/* Set the session error string to describe the server certificate
 * verification failures in 'failures' (a bitmask of NE_SSL_* flags).
 * Each recognised reason is listed in table order, separated by ", ". */
static void verify_err(ne_session *sess, int failures)
{
    static const struct {
	int bit;
	const char *str;
    } reasons[] = {
	{ NE_SSL_NOTYETVALID, N_("not yet valid") },
	{ NE_SSL_EXPIRED, N_("Server certificate has expired") },
	{ NE_SSL_CNMISMATCH, N_("Certificate hostname mismatch") },
	{ NE_SSL_UNKNOWNCA, N_("issuer not trusted") },
	{ 0, NULL }
    };
    int n, flag = 0;

    /* bounded copy into the BUFSIZ-sized error buffer */
    ne_set_error(sess, "%s", _("Server certificate verification failed: "));

    for (n = 0; reasons[n].bit; n++) {
	if (failures & reasons[n].bit) {
	    /* strncat's third argument is the maximum number of characters
	     * to append, not the destination size: pass the space still
	     * free in the BUFSIZ-sized buffer, leaving room for the NUL. */
	    if (flag)
		strncat(sess->error, ", ", BUFSIZ - strlen(sess->error) - 1);
	    strncat(sess->error, _(reasons[n].str),
		    BUFSIZ - strlen(sess->error) - 1);
	    flag = 1;
	}
    }
}

/* Size (including the NUL) reserved for one X509 name field. */
#define ATTBUFSIZ (1028)

/* Extract the attribute 'nid' from 'name' as text, appending it
 * (NUL-terminated) to 'dump'.  Returns a pointer to the stored text, or
 * NULL if the attribute is absent.
 *
 * The returned pointer points into dump->data.  It stays valid only as
 * long as the buffer is not reallocated by a later ne_buffer_grow(), so
 * callers storing several fields must size 'dump' up front. */
static const char *getx509field(X509_NAME *name, int nid, 
				ne_buffer *dump)
{
    char *dest;
    int len;

    ne_buffer_grow(dump, dump->used + ATTBUFSIZ);
    dest = dump->data + dump->used;
    len = X509_NAME_get_text_by_NID(name, nid, dest, ATTBUFSIZ);
    if (len <= 0)
	return NULL;

    /* account for the text plus its terminating NUL */
    dump->used += (size_t)len + 1;
    return dest;
}

/* Return a malloc-allocated string representation of the ASN.1 time
 * 'tm', as formatted by ASN1_TIME_print() (truncated to 63 bytes).  If
 * the time cannot be formatted, returns a copy of the translated
 * "[invalid date]" string instead.
 *
 * TODO: would be better to parse out the raw ASN.1 string and give that
 * to the application in some form which is localisable, e.g. time_t. */
static char *asn1time_to_string(ASN1_TIME *tm)
{
    char buf[64];
    int len = 0;
    BIO *bio = BIO_new(BIO_s_mem());

    if (bio) {
	/* BIO_read() does not NUL-terminate what it copies: read at most
	 * sizeof(buf) - 1 bytes so there is room to add the terminator. */
	if (ASN1_TIME_print(bio, tm))
	    len = BIO_read(bio, buf, sizeof(buf) - 1);
	BIO_free(bio);
    }

    if (len <= 0)
	return ne_strdup(_("[invalid date]"));

    buf[len] = '\0';
    return ne_strdup(buf);
}

/* Case-fold an ASCII upper-case letter to lower case; other bytes are
 * returned unchanged. */
static int fold_host_char(unsigned char ch)
{
    return (ch >= 'A' && ch <= 'Z') ? ch - 'A' + 'a' : ch;
}

/* Compare two hostnames, ignoring ASCII case.  Returns zero if they are
 * equal and non-zero otherwise. */
static int hostname_differs(const char *a, const char *b)
{
    for (;; a++, b++) {
	int ca = fold_host_char((unsigned char)*a);
	int cb = fold_host_char((unsigned char)*b);

	if (ca != cb)
	    return ca - cb;
	if (ca == '\0')
	    return 0;
    }
}

/* Return non-zero if the hostname from the certificate ('cn') doesn't
 * match the hostname used for the session ('hostname').  DNS names are
 * case-insensitive, so the comparison ignores ASCII case.
 *
 * Matching rules:
 *  - if 'hostname' is unqualified (contains no '.'), 'cn' is truncated
 *    IN PLACE at its first '.', so e.g. "www.example.com" matches "www";
 *  - otherwise, a leading "*." in 'cn' stands for exactly the first
 *    label of 'hostname'.
 *
 * TODO: could do more advanced wildcard matching using fnmatch() here,
 * if fnmatch is present. */
static int match_hostname(char *cn, const char *hostname)
{
    const char *dot = strchr(hostname, '.');

    if (dot == NULL) {
	char *pnt = strchr(cn, '.');
	/* hostname is not fully-qualified; unqualify the cn. */
	if (pnt != NULL) {
	    *pnt = '\0';
	}
    }
    else if (strncmp(cn, "*.", 2) == 0) {
	hostname = dot + 1;
	cn += 2;
    }
    return hostname_differs(cn, hostname);
}

/* Fill in the friendly DN structure 'dn' from the X509 name 'xn'.  Each
 * field's text is appended to 'dump', and dn's members point into it;
 * absent attributes are left as NULL.  Six fields are stored, so 'dump'
 * should have room for 6 * ATTBUFSIZ bytes to avoid reallocation. */
static void make_dname(ne_ssl_dname *dn, X509_NAME *xn, ne_buffer *dump)
{
#define DN_FIELD(nid) getx509field(xn, (nid), dump)
    dn->country            = DN_FIELD(NID_countryName);
    dn->state              = DN_FIELD(NID_stateOrProvinceName);
    dn->locality           = DN_FIELD(NID_localityName);
    dn->organization       = DN_FIELD(NID_organizationName);
    dn->organizationalUnit = DN_FIELD(NID_organizationalUnitName);
    dn->commonName         = DN_FIELD(NID_commonName);
#undef DN_FIELD
}

/* Verify the SSL server certificate 'cert' presented on connection 'ssl'.
 *
 * The following failures are collected into a NE_SSL_* bitmask:
 *  - NE_SSL_NOTYETVALID / NE_SSL_EXPIRED: outside the validity period;
 *  - NE_SSL_CNMISMATCH: commonName differs from the session hostname;
 *  - NE_SSL_UNKNOWNCA: issuer untrusted, per OpenSSL's chain
 *    verification result.
 * Any other OpenSSL verification error, or a missing commonName, is fatal.
 *
 * If failures occurred and the application registered a verify callback,
 * the callback decides (a non-zero return rejects).  Otherwise any
 * failure rejects.
 *
 * Returns NE_OK if the certificate is accepted, or NE_ERROR with the
 * session error string set. */
static int check_certificate(ne_session *sess, SSL *ssl, X509 *cert)
{
    X509_NAME *subj = X509_get_subject_name(cert);
    X509_NAME *issuer = X509_get_issuer_name(cert);
    ASN1_TIME *notBefore = X509_get_notBefore(cert);
    ASN1_TIME *notAfter = X509_get_notAfter(cert);
    char buf[ATTBUFSIZ];
    int ret, failures = 0;
    long result;

    /* check expiry dates */
    if (X509_cmp_current_time(notBefore) >= 0) {
	failures |= NE_SSL_NOTYETVALID;
    }
    else if (X509_cmp_current_time(notAfter) <= 0) {
	failures |= NE_SSL_EXPIRED;
    } 

    /* retrieve the commonName, compare with the server hostname. */
    ret = X509_NAME_get_text_by_NID(subj, NID_commonName, buf, ATTBUFSIZ);
    if (ret < 1) {
	ne_set_error(sess, 
		     _("Server certificate was missing commonName attribute"));
	return NE_ERROR;
    }
    
    if (match_hostname(buf, sess->server.hostname)) {
	failures |= NE_SSL_CNMISMATCH;
    }

    /* get the result of the cert verification out of OpenSSL */
    result = SSL_get_verify_result(ssl);

    NE_DEBUG(NE_DBG_HTTP, "Verify result: %ld = %s\n", result,
	     X509_verify_cert_error_string(result));

#if NE_DEBUGGING
    if (ne_debug_mask & NE_DBG_HTTP)
	X509_print_fp(ne_debug_stream, cert);
#endif

    switch (result) {
    case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY:
    case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
    case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
	/* TODO: and probably more result codes here... */
	failures |= NE_SSL_UNKNOWNCA;
	break;
    case X509_V_ERR_CERT_NOT_YET_VALID:
    case X509_V_ERR_CERT_HAS_EXPIRED:
	/* ignore these, since we've already noticed them . */
	break;
    case X509_V_OK:
	/* it's okay. */
	break;
    default:
	/* TODO: tricky to handle the 30-odd failure cases OpenSSL
	 * presents here (see x509_vfy.h), and present a useful API to
	 * the application so it in turn can then present a meaningful
	 * UI to the user.  The only thing to do really would be to
	 * pass back the error string, but that's not localisable.  So
	 * just fail the verification here - better safe than
	 * sorry. */
	ne_set_error(sess, _("Certificate verification error: %s"),
		     X509_verify_cert_error_string(result));
	return NE_ERROR;
    }

    if (sess->ssl_verify_fn && failures) {
	ne_ssl_certificate c;
	ne_ssl_dname sdn = {0}, idn = {0};
	/* make_dname() is called twice below, storing up to six fields
	 * each (twelve in total), and each field takes at most ATTBUFSIZ
	 * bytes.  getx509field() returns pointers into dump->data, so if
	 * ne_buffer_grow() reallocated the buffer part-way through, the
	 * fields already stored in sdn/idn would dangle.  Reserve room for
	 * all twelve up front, so each grow request fits within the initial
	 * allocation (assuming ne_buffer_grow only reallocates when the
	 * requested size exceeds it). */
	ne_buffer *dump = ne_buffer_create_sized(ATTBUFSIZ * 12);
	char *from, *until;
	
	/* Do the gymnastics to retrieve attributes out of the
	 * X509_NAME, store them in a temporary buffer (dump), and set
	 * the structure fields up to pass to the verify callback.
	 * Using a temp buffer means that this can be done with only a
	 * few malloc() calls and only one free(). */

	dump->used = 0; /* ignore the initial \0 */

	make_dname(&sdn, subj, dump);
	make_dname(&idn, issuer, dump);
	
	c.subject = &sdn;
	c.issuer = &idn;
	c.from = from = asn1time_to_string(notBefore);
	c.until = until = asn1time_to_string(notAfter);

	if (sess->ssl_verify_fn(sess->ssl_verify_ud, failures, &c)) {
	    ne_set_error(sess, _("Certificate verification failed"));
	    ret = NE_ERROR;
	} else {
	    ret = NE_OK;
	}

	ne_buffer_destroy(dump);
	free(from);
	free(until);

    } else if (failures != 0) {
	verify_err(sess, failures);
	ret = NE_ERROR;
    } else {
	/* well, okay then if you insist. */
	ret = NE_OK;
    }

    return ret;
}

/* Callback invoked by OpenSSL when the SSL server requests a client
 * certificate.
 *
 * If no client key is loaded yet and the application registered a
 * provider callback, that callback is first called with the server
 * certificate's subject DN, giving it a chance to load one.
 *
 * Returns 1 and sets *cert / *pkey if both a client certificate and key
 * are loaded; otherwise returns 0 (no certificate is sent). */
static int provide_client_cert(SSL *ssl, X509 **cert, EVP_PKEY **pkey)
{
    ne_session *sess = SSL_get_app_data(ssl);

    if (!sess->client_key && sess->ssl_provide_fn) {
	/* Passed to the callback even when the peer has no subject, so
	 * zero it: the callback must never see uninitialised pointers. */
	ne_ssl_dname dn = {0};
	/* Room for all six fields make_dname() may store, so the buffer
	 * is not reallocated under the pointers it hands out. */
	ne_buffer *buf = ne_buffer_create_sized(ATTBUFSIZ * 6);
	/* SSL_get_peer_certificate() returns a new reference; freed below. */
	X509 *peer = SSL_get_peer_certificate(ssl);
	X509_NAME *subject = peer ? X509_get_subject_name(peer) : NULL;

	buf->used = 0; /* ignore the initial \0 */

	if (subject) {
	    make_dname(&dn, subject, buf);
	}

	NE_DEBUG(NE_DBG_HTTP, "Calling client certificate provider...\n");
	sess->ssl_provide_fn(sess->ssl_provide_ud, sess, &dn);
	ne_buffer_destroy(buf);
	if (peer != NULL) {
	    X509_free(peer);
	}
    }

    if (sess->client_key && sess->client_cert) {
	NE_DEBUG(NE_DBG_HTTP, "Supplying client certificate.\n");
	/* OpenSSL installs the returned certificate and key into the SSL
	 * object and then drops one reference to each.  Take an extra
	 * reference, so the session's copies stay valid for later
	 * handshakes. */
	CRYPTO_add(&sess->client_cert->references, 1, CRYPTO_LOCK_X509);
	CRYPTO_add(&sess->client_key->references, 1, CRYPTO_LOCK_EVP_PKEY);
	*cert = sess->client_cert;
	*pkey = sess->client_key;
	return 1;
    } else {
	NE_DEBUG(NE_DBG_HTTP, "No client certificate supplied.\n");
	return 0;
    }
}

/* For internal use only.
 *
 * Performs the SSL handshake on the session's socket, then checks the
 * server certificate:
 *  - on the first connection it is verified with check_certificate()
 *    and cached in sess->server_cert;
 *  - on later connections it must be identical to the cached one.
 *
 * After a successful negotiation, the SSL session is cached for
 * resumption (if not already) and the notify callback, if any, is told
 * the connection is secure.
 *
 * Returns NE_OK, or NE_ERROR with the session error string set. */
int ne_negotiate_ssl(ne_request *req)
{
    ne_session *sess = req->session;
    SSL *ssl;
    X509 *cert;

    NE_DEBUG(NE_DBG_HTTP, "Doing SSL negotiation.\n");

    if (sock_enable_ssl_os(sess->socket, sess->ssl_context, 
			   sess->ssl_sess, &ssl, sess)) {
	NE_DEBUG(NE_DBG_HTTP, "sock_eso failed: %s\n", 
		 sock_get_error(sess->socket));
	if (sess->ssl_sess) {
	    /* remove cached session. */
	    SSL_SESSION_free(sess->ssl_sess);
	    sess->ssl_sess = NULL;
	}
	ne_set_error(sess, _("SSL negotiation failed"));
	return NE_ERROR;
    }	
    
    /* SSL_get_peer_certificate() returns a new reference: on every path
     * below it is either kept as sess->server_cert or freed. */
    cert = SSL_get_peer_certificate(ssl);
    if (cert == NULL) {
	ne_set_error(sess, _("SSL server did not present certificate"));
	return NE_ERROR;
    }

    if (sess->server_cert) {
	/* The cached certificate has already passed verification, so the
	 * new copy only needs comparing and can then be abandoned. */
	int changed = X509_cmp(cert, sess->server_cert);

	X509_free(cert);
	if (changed) {
	    /* This could be a MITM attack: fail the request. */
	    ne_set_error(sess, _("Server certificate changed: connection intercepted?"));
	    X509_free(sess->server_cert);
	    sess->server_cert = NULL;
	    return NE_ERROR;
	}
    } else {
	/* new connection: verify the cert. */
	if (check_certificate(sess, ssl, cert)) {
	    NE_DEBUG(NE_DBG_HTTP, "SSL certificate checks failed: %s\n",
		     sess->error);
	    X509_free(cert);
	    return NE_ERROR;
	}
	/* cache the cert. */
	sess->server_cert = cert;
    }
    
    if (!sess->ssl_sess) {
	/* store the session. */
	sess->ssl_sess = SSL_get1_session(ssl);
    }

    if (sess->notify_cb) {
	sess->notify_cb(sess->notify_ud, ne_conn_secure, SSL_get_version(ssl));
    }

    return NE_OK;
}

/* Add the CA certificates in 'file' to the session's trusted set.
 * Returns 0 on success, 1 on failure. */
int ne_ssl_load_ca(ne_session *sess, const char *file)
{
    int ok = SSL_CTX_load_verify_locations(sess->ssl_context, file, NULL);
    return ok ? 0 : 1;
}

/* Trust OpenSSL's default CA certificate locations.  Returns 0 on
 * success, 1 on failure. */
int ne_ssl_load_default_ca(ne_session *sess)
{
    int ok = SSL_CTX_set_default_verify_paths(sess->ssl_context);
    return ok ? 0 : 1;
}

/* OpenSSL password callback wrapping the application's key-password
 * callback.  'userdata' is the session.  The application callback
 * writes the password into 'buf' (at most 'len' bytes); 'rwflag' is
 * unused.  Returns the password length, or -1 if no callback is set,
 * the buffer is unusable, or the application callback fails. */
static int privkey_prompt(char *buf, int len, int rwflag, void *userdata)
{
    ne_session *sess = userdata;
    const char *nul;

    (void)rwflag;

    if (len <= 0 || sess->ssl_keypw_fn == NULL)
	return -1;

    if (sess->ssl_keypw_fn(sess->ssl_keypw_ud, buf, len))
	return -1;

    /* Obscurely OpenSSL requires the callback to return the length of
     * the password, this seems a bit weird so we don't expose this in
     * the neon API.  Never scan past the 'len' bytes we own, even if
     * the application failed to NUL-terminate the password. */
    nul = memchr(buf, '\0', (size_t)len);
    return nul ? (int)(nul - buf) : len;
}

/* Register the application's private-key password callback 'fn' (with
 * userdata 'ud'), and install privkey_prompt() as the SSL context's
 * default password callback, with the session as its userdata. */
void ne_ssl_keypw_prompt(ne_session *sess, ne_ssl_keypw_fn fn, void *ud)
{
    SSL_CTX *ctx = sess->ssl_context;

    /* remember the application callback first... */
    sess->ssl_keypw_ud = ud;
    sess->ssl_keypw_fn = fn;
    /* ...then route OpenSSL's password requests through our wrapper. */
    SSL_CTX_set_default_passwd_cb_userdata(ctx, sess);
    SSL_CTX_set_default_passwd_cb(ctx, privkey_prompt);
}

/* Load the client certificate and private key from the PKCS#12 file 'fn'.
 * If a key-password callback is registered, it supplies the password.
 *
 * Returns 0 on success.  On failure returns -1 and sets the session
 * error; once the file has been read, a failure leaves the session with
 * no client certificate or key. */
int ne_ssl_load_pkcs12(ne_session *sess, const char *fn)
{
    /* you are lost in a maze of twisty crypto algorithms... */
    PKCS12 *p12;
    FILE *fp;
    int ret;
    char *password = NULL, buf[BUFSIZ];
    EVP_PKEY *key = NULL;
    X509 *cert = NULL;
    const char *reason;

    /* PKCS#12 is a binary (DER) format: open in binary mode so that no
     * newline translation happens on platforms which distinguish. */
    fp = fopen(fn, "rb");
    if (fp == NULL) {
	ne_set_error(sess, _("Could not open file `%s': %s"), fn,
		     strerror(errno));
	return -1;
    }		     

    p12 = d2i_PKCS12_fp(fp, NULL);

    fclose(fp);
    
    if (p12 == NULL) {
	ne_set_error(sess, _("Could not read certificate from file `%s'"),
		     fn);
	return -1;
    }

    sess->client_key = NULL;
    sess->client_cert = NULL;

    if (sess->ssl_keypw_fn) {
	if (sess->ssl_keypw_fn(sess->ssl_keypw_ud, buf, BUFSIZ) == 0)
	    password = buf;
    }

    /* Parse into locals: on failure, some OpenSSL versions free the
     * objects they already created without resetting the output
     * pointers, so they must never be adopted by the session. */
    ret = PKCS12_parse(p12, password, &key, &cert, NULL);
    PKCS12_free(p12);

    if (ret != 1) {
	reason = ERR_reason_error_string(ERR_get_error());
	ne_set_error(sess,
		     _("Error parsing certificate (incorrect password?): %s"),
		     reason ? reason : "");
	return -1;
    }

    sess->client_key = key;
    sess->client_cert = cert;
    return 0;
}

/* Load the client certificate from the PEM file 'cert'.  The private key
 * is read from the PEM file 'key', or, if 'key' is NULL, from the rest
 * of 'cert'.  If the application registered a key-password callback,
 * it is used to decrypt an encrypted key.
 *
 * Returns 0 on success.  On failure returns -1 with the session error
 * set, and the session is left with no client certificate or key. */
int ne_ssl_load_pem(ne_session *sess, const char *cert, const char *key)
{
    FILE *fp;

    sess->client_key = NULL;
    sess->client_cert = NULL;

    fp = fopen(cert, "r");
    if (fp == NULL) {
	ne_set_error(sess, _("Could not open file `%s': %s"), cert,
		     strerror(errno));
	return -1;
    }

    sess->client_cert = PEM_read_X509(fp, NULL, NULL, NULL);
    if (sess->client_cert == NULL) {
	ne_set_error(sess, _("Could not read certificate"));
	fclose(fp);
	return -1;
    }

    if (key != NULL) {
	fclose(fp);
	fp = fopen(key, "r");
	if (fp == NULL) {
	    ne_set_error(sess, 
			 _("Could not open private key file `%s': %s"),
			 key, strerror(errno));
	    goto drop_cert;
	}
    }

    /* Only pass the session as callback data together with our own
     * callback: with a NULL callback, OpenSSL's default callback may
     * treat non-NULL userdata as the password string itself. */
    if (sess->ssl_keypw_fn) {
	sess->client_key = PEM_read_PrivateKey(fp, NULL, privkey_prompt, sess);
    } else {
	sess->client_key = PEM_read_PrivateKey(fp, NULL, NULL, NULL);
    }
    fclose(fp);

    if (sess->client_key == NULL) {
	ne_set_error(sess, 
		     _("Could not parse private key (incorrect password?)"));
	goto drop_cert;
    }
    
    return 0;

drop_cert:
    /* The certificate was loaded by this call and never handed to
     * OpenSSL, so this is its only reference. */
    X509_free(sess->client_cert);
    sess->client_cert = NULL;
    return -1;
}

/* Returns the server certificate cached by ne_negotiate_ssl() (owned by
 * the session), or NULL if none has been cached. */
X509 *ne_ssl_server_cert(ne_session *sess)
{
    X509 *cached = sess->server_cert;
    return cached;
}

#else

/* Common body for the SSL stubs: record why the call failed and report
 * NE_ERROR. */
static int ssl_not_supported(ne_session *sess)
{
    ne_set_error(sess, _("SSL is not supported"));
    return NE_ERROR;
}

/* Stubs to make the library have the same ABI whether or not SSL
 * support is enabled. */
int ne_negotiate_ssl(ne_request *req)
{
    return ssl_not_supported(req->session);
}

int ne_ssl_load_ca(ne_session *sess, const char *file)
{
    return ssl_not_supported(sess);
}

int ne_ssl_load_default_ca(ne_session *sess)
{
    return ssl_not_supported(sess);
}

int ne_ssl_load_pkcs12(ne_session *sess, const char *fn)
{
    return ssl_not_supported(sess);
}

int ne_ssl_load_pem(ne_session *sess, const char *cert, const char *key)
{
    return ssl_not_supported(sess);
}

void ne_ssl_keypw_prompt(ne_session *sess, ne_ssl_keypw_fn fn, void *ud)
{
    /* nothing to configure without SSL support */
}

#endif


