/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

/***********************************************************************
**  1996 - Netscape Communications Corporation
**
** Name: accept.c
**
** Description: Run accept() successful connection tests.
**
** Modification History:
** 04-Jun-97 AGarcia - Reconvert test file to return a 0 for PASS and a 1 for FAIL
** 13-May-97 AGarcia - Converted the test to accommodate the debug_mode.
**             The debug mode will print all of the printfs associated with this test.
**             The regress mode will be the default mode. Since the regress tool limits
**             the output to a one-line status (PASS or FAIL), all of the printf statements
**             have been guarded with an if (debug_mode) statement.
** 04-Jun-97 AGarcia - Removed the Test_Result function. The regress tool has been updated to
**             recognize the return code from the main program.
** 12-Jun-97 Revert to return code 0 and 1.
***********************************************************************/

/***********************************************************************
** Includes
***********************************************************************/

#include "nspr.h"
#include "prpriv.h"

#include <stdlib.h>
#include <string.h>

#include "plgetopt.h"
#include "plerror.h"

/* First port RunTest() tries to bind; it steps upward while the port is in use. */
#define BASE_PORT 10000

/* Number of bytes the client sends per connection and the server expects. */
#define CLIENT_DATA        128

/*
 * Server-side accept flavour selected in RunTest(). The FAST variants
 * (PR_NTFast_* APIs) are only compiled in on WINNT.
 */
#define ACCEPT_NORMAL        0x1
#define ACCEPT_FAST        0x2
#define ACCEPT_READ        0x3
#define ACCEPT_READ_FAST    0x4
#define ACCEPT_READ_FAST_CB    0x5

/*
 * Client behaviour selected in ClientThread():
 *   CLIENT_NORMAL          connect and send CLIENT_DATA bytes
 *   CLIENT_TIMEOUT_ACCEPT  never connect (server accept should time out)
 *   CLIENT_TIMEOUT_SEND    connect but never send (server read should time out)
 */
#define CLIENT_NORMAL        0x1
#define CLIENT_TIMEOUT_ACCEPT    0x2
#define CLIENT_TIMEOUT_SEND    0x3

/* Maximum number of extra bind attempts when the port is already in use. */
#define SERVER_MAX_BIND_COUNT        100

/* I/O timeout in seconds used by both sides (longer on OS/2). */
#if defined(XP_OS2)
#define TIMEOUTSECS 10
#else
#define TIMEOUTSECS 2
#endif
/* TIMEOUTSECS converted to an interval; initialised in main(). */
PRIntervalTime timeoutTime;

/* Iterations per scenario; overridden by the -c command-line option. */
static PRInt32 count = 1;
/* Destination of all diagnostic output (set to PR_STDERR in main()). */
static PRFileDesc *output;
/* Address the client connects to (loopback, listener's port). */
static PRNetAddr serverAddr;
/* Scope of the client thread; -G switches to PR_GLOBAL_THREAD. */
static PRThreadScope thread_scope = PR_LOCAL_THREAD;
/* CLIENT_* code handed to ClientThread() by address. */
static PRInt32 clientCommand;
/* Server-side remaining-iteration counter used by RunTest(). */
static PRInt32 iterations;
static PRStatus rv;
static PRFileDesc *listenSock;
static PRFileDesc *clientSock = NULL;
static PRNetAddr listenAddr;
static PRNetAddr clientAddr;
static PRThread *clientThread;
/* Peer address returned by the AcceptRead variants. */
static PRNetAddr *raddr;
/*
 * Server receive buffer. The extra space beyond 4096 bytes is presumably
 * the room the AcceptRead APIs need for the two addresses they store after
 * the data -- NOTE(review): confirm against the NSPR PR_AcceptRead docs.
 */
static char buf[4096 + 2*sizeof(PRNetAddr) + 32];
static PRInt32 status;
static PRInt32 bytesRead;

/* Non-zero once any check has failed; also the process exit code. */
PRIntn failed_already=0;
/* Non-zero when -d was given: enables diagnostic printing. */
PRIntn debug_mode;

/*
 * Test_Assert -- failure hook invoked by TEST_ASSERT.
 *
 * msg   text of the failed expression (TEST_ASSERT passes #expr)
 * file  source file containing the failed check
 * line  source line of the failed check
 *
 * Always marks the run as failed (failed_already = 1). The location and
 * expression are printed only in debug mode.
 */
void Test_Assert(const char *msg, const char *file, PRIntn line)
{
    failed_already=1;
    if (debug_mode) {
        PR_fprintf(output,  "@%s:%d ", file, line);
        /* msg is data, never a format string; end the diagnostic line. */
        PR_fprintf(output, "%s\n", msg);
    }
}  /* Test_Assert */

/*
 * TEST_ASSERT(expr) -- if expr is false, report it through Test_Assert()
 * with the stringified expression and source location. Execution continues
 * after a failed check. The do/while(0) wrapper makes the macro a single
 * statement, so it is safe inside an unbraced if/else.
 */
#define TEST_ASSERT(expr) \
    do { \
        if (!(expr)) { \
            Test_Assert(#expr, __FILE__, __LINE__); \
        } \
    } while (0)

#ifdef WINNT
/* Cookie RunTest() passes alongside timeout_callback; checked on invocation. */
#define CALLBACK_MAGIC 0x12345678

/*
 * timeout_callback -- callback handed to
 * PR_NTFast_AcceptRead_WithTimeoutCallback() by RunTest().
 *
 * magic  the opaque argument registered with the callback; it must be
 *        CALLBACK_MAGIC, otherwise the run is flagged as failed.
 */
void timeout_callback(void *magic)
{
    TEST_ASSERT(magic == (void *)CALLBACK_MAGIC);
    if (!debug_mode) {
        return;
    }
    PR_fprintf(output, "timeout callback called okay\n");
}
#endif


/*
 * ClientThread -- client side of a RunTest() scenario.
 *
 * _action points to a PRInt32 holding one of the CLIENT_* codes (RunTest
 * passes &clientCommand). For each of `count` iterations the thread creates
 * a TCP socket and then:
 *   CLIENT_NORMAL          connects to the listener on loopback and sends
 *                          CLIENT_DATA bytes;
 *   CLIENT_TIMEOUT_SEND    connects, then sleeps TIMEOUTSECS + 1 seconds
 *                          instead of sending;
 *   CLIENT_TIMEOUT_ACCEPT  never connects; sleeps TIMEOUTSECS + 1 seconds.
 * The socket is closed at the end of every iteration. Any failure sets
 * failed_already (in every mode, matching RunTest) and ends the thread;
 * diagnostics are printed only in debug mode.
 */
static void PR_CALLBACK
ClientThread(void *_action)
{
    PRInt32 action = * (PRInt32 *) _action;
    /* Private counter: the global `iterations` is decremented concurrently
     * by the server loop in RunTest(). */
    PRInt32 remaining = count;
    PRFileDesc *sock = NULL;

    serverAddr.inet.family = PR_AF_INET;
    serverAddr.inet.port = listenAddr.inet.port;
    serverAddr.inet.ip = PR_htonl(PR_INADDR_LOOPBACK);

    for (; remaining--;) {
        PRInt32 result;
        char data[CLIENT_DATA];

        memset(data, 0xaf, sizeof(data)); /* initialize with arbitrary data */
        sock = PR_NewTCPSocket();
        if (!sock) {
            failed_already=1;
            if (debug_mode) {
                PR_fprintf(output, "client: unable to create socket\n");
            }
            return;
        }

        if (action != CLIENT_TIMEOUT_ACCEPT) {

            if ((result = PR_Connect(sock, &serverAddr,
                                     timeoutTime)) < 0) {
                failed_already=1;
                if (debug_mode) {
                    PR_fprintf(output,
                               "client: unable to connect to server (%ld, %ld, %ld, %ld)\n",
                               remaining, result, PR_GetError(), PR_GetOSError());
                }
                goto ErrorExit;
            }

            if (action != CLIENT_TIMEOUT_SEND) {
                if ((result = PR_Send(sock, data, CLIENT_DATA,
                                      0, timeoutTime)) < 0) {
                    failed_already=1;
                    if (debug_mode) {
                        PR_fprintf(output,
                                   "client: unable to send to server (%d, %ld, %ld)\n",
                                   CLIENT_DATA, result, PR_GetError());
                    }
                    goto ErrorExit;
                }
            } else {
                /* Stay silent longer than timeoutTime so the server's read times out. */
                PR_Sleep(PR_SecondsToInterval(TIMEOUTSECS + 1));
            }
        } else {
            /* Never connect, and outlast timeoutTime so the server's accept times out. */
            PR_Sleep(PR_SecondsToInterval(TIMEOUTSECS + 1));
        }
        if (debug_mode) {
            PR_fprintf(output, ".");
        }
        PR_Close(sock);
        sock = NULL;
    }
    if (debug_mode) {
        PR_fprintf(output, "\n");
    }

ErrorExit:
    if (sock != NULL) {
        PR_Close(sock);
    }
}


/*
 * RunTest -- run one server/client scenario.
 *
 * acceptType    one of the ACCEPT_* codes; selects the accept API the server
 *               side exercises (ACCEPT_FAST, ACCEPT_READ_FAST and
 *               ACCEPT_READ_FAST_CB are only handled on WINNT builds).
 * clientAction  one of the CLIENT_* codes; handed to ClientThread and used
 *               to decide which results the server should expect.
 *
 * Binds a listening socket starting at BASE_PORT, stepping the port by 2
 * while it is in use, for up to SERVER_MAX_BIND_COUNT retries. It then
 * starts ClientThread and performs `count` accepts, checking each outcome
 * with TEST_ASSERT. Setup failures set failed_already and return early.
 * The listening socket is closed on every path.
 */
static void
RunTest(PRInt32 acceptType, PRInt32 clientAction)
{
    int i;

    /* First bind to the socket */
    listenSock = PR_NewTCPSocket();
    if (!listenSock) {
        failed_already=1;
        if (debug_mode) {
            PR_fprintf(output, "unable to create listen socket\n");
        }
        return;
    }
    memset(&listenAddr, 0, sizeof(listenAddr));
    listenAddr.inet.family = PR_AF_INET;
    listenAddr.inet.port = PR_htons(BASE_PORT);
    listenAddr.inet.ip = PR_htonl(PR_INADDR_ANY);
    /*
     * try a few times to bind server's address, if addresses are in
     * use
     */
    i = 0;
    while (PR_Bind(listenSock, &listenAddr) == PR_FAILURE) {
        if (PR_GetError() == PR_ADDRESS_IN_USE_ERROR) {
            /* inet.port is stored in network byte order: step it in host order. */
            listenAddr.inet.port =
                PR_htons((PRUint16)(PR_ntohs(listenAddr.inet.port) + 2));
            if (i++ < SERVER_MAX_BIND_COUNT) {
                continue;
            }
        }
        failed_already=1;
        if (debug_mode) {
            PR_fprintf(output,"accept: ERROR - PR_Bind failed\n");
        }
        goto CloseListener;
    }


    rv = PR_Listen(listenSock, 100);
    if (rv == PR_FAILURE) {
        failed_already=1;
        if (debug_mode) {
            PR_fprintf(output, "unable to listen\n");
        }
        goto CloseListener;
    }

    clientCommand = clientAction;
    clientThread = PR_CreateThread(PR_USER_THREAD, ClientThread,
                                   (void *)&clientCommand, PR_PRIORITY_NORMAL, thread_scope,
                                   PR_JOINABLE_THREAD, 0);
    if (!clientThread) {
        failed_already=1;
        if (debug_mode) {
            PR_fprintf(output, "error creating client thread\n");
        }
        goto CloseListener;
    }

    iterations = count;
    for (; iterations--;) {
        switch (acceptType) {
            case ACCEPT_NORMAL:
                clientSock = PR_Accept(listenSock, &clientAddr,
                                       timeoutTime);
                switch(clientAction) {
                    case CLIENT_TIMEOUT_ACCEPT:
                        TEST_ASSERT(clientSock == 0);
                        TEST_ASSERT(PR_GetError() == PR_IO_TIMEOUT_ERROR);
                        break;
                    case CLIENT_NORMAL:
                        TEST_ASSERT(clientSock);
                        /* A failed accept is already recorded; never PR_Recv on NULL. */
                        if (clientSock) {
                            bytesRead = PR_Recv(clientSock,
                                                buf,  CLIENT_DATA,  0,  timeoutTime);
                            TEST_ASSERT(bytesRead == CLIENT_DATA);
                        }
                        break;
                    case CLIENT_TIMEOUT_SEND:
                        TEST_ASSERT(clientSock);
                        if (clientSock) {
                            bytesRead = PR_Recv(clientSock,
                                                buf,  CLIENT_DATA,  0,  timeoutTime);
                            TEST_ASSERT(bytesRead == -1);
                            TEST_ASSERT(PR_GetError() == PR_IO_TIMEOUT_ERROR);
                        }
                        break;
                }
                break;
            case ACCEPT_READ:
                status = PR_AcceptRead(listenSock, &clientSock,
                                       &raddr, buf, CLIENT_DATA, timeoutTime);
                switch(clientAction) {
                    case CLIENT_TIMEOUT_ACCEPT:
                        /* Invalid test case */
                        TEST_ASSERT(0);
                        break;
                    case CLIENT_NORMAL:
                        TEST_ASSERT(clientSock);
                        TEST_ASSERT(status == CLIENT_DATA);
                        break;
                    case CLIENT_TIMEOUT_SEND:
                        TEST_ASSERT(status == -1);
                        TEST_ASSERT(PR_GetError() == PR_IO_TIMEOUT_ERROR);
                        break;
                }
                break;
#ifdef WINNT
            case ACCEPT_FAST:
                clientSock = PR_NTFast_Accept(listenSock,
                                              &clientAddr, timeoutTime);
                switch(clientAction) {
                    case CLIENT_TIMEOUT_ACCEPT:
                        TEST_ASSERT(clientSock == 0);
                        if (debug_mode) {
                            PR_fprintf(output, "PR_GetError is %ld\n", PR_GetError());
                        }
                        TEST_ASSERT(PR_GetError() == PR_IO_TIMEOUT_ERROR);
                        break;
                    case CLIENT_NORMAL:
                        TEST_ASSERT(clientSock);
                        if (clientSock) {
                            bytesRead = PR_Recv(clientSock,
                                                buf,  CLIENT_DATA,  0,  timeoutTime);
                            TEST_ASSERT(bytesRead == CLIENT_DATA);
                        }
                        break;
                    case CLIENT_TIMEOUT_SEND:
                        TEST_ASSERT(clientSock);
                        if (clientSock) {
                            bytesRead = PR_Recv(clientSock,
                                                buf,  CLIENT_DATA,  0,  timeoutTime);
                            TEST_ASSERT(bytesRead == -1);
                            TEST_ASSERT(PR_GetError() == PR_IO_TIMEOUT_ERROR);
                        }
                        break;
                }
                break;
            case ACCEPT_READ_FAST:
                status = PR_NTFast_AcceptRead(listenSock,
                                              &clientSock, &raddr, buf, 4096, timeoutTime);
                switch(clientAction) {
                    case CLIENT_TIMEOUT_ACCEPT:
                        /* Invalid test case */
                        TEST_ASSERT(0);
                        break;
                    case CLIENT_NORMAL:
                        TEST_ASSERT(clientSock);
                        TEST_ASSERT(status == CLIENT_DATA);
                        break;
                    case CLIENT_TIMEOUT_SEND:
                        TEST_ASSERT(clientSock == NULL);
                        TEST_ASSERT(status == -1);
                        TEST_ASSERT(PR_GetError() == PR_IO_TIMEOUT_ERROR);
                        break;
                }
                break;
            case ACCEPT_READ_FAST_CB:
                status = PR_NTFast_AcceptRead_WithTimeoutCallback(
                             listenSock, &clientSock, &raddr, buf, 4096,
                             timeoutTime, timeout_callback, (void *)CALLBACK_MAGIC);
                switch(clientAction) {
                    case CLIENT_TIMEOUT_ACCEPT:
                        /* Invalid test case */
                        TEST_ASSERT(0);
                        break;
                    case CLIENT_NORMAL:
                        TEST_ASSERT(clientSock);
                        TEST_ASSERT(status == CLIENT_DATA);
                        break;
                    case CLIENT_TIMEOUT_SEND:
                        if (debug_mode) {
                            /* %p: a pointer does not fit in a 32-bit %lx on 64-bit builds. */
                            PR_fprintf(output, "clientSock = %p\n", (void *)clientSock);
                        }
                        TEST_ASSERT(clientSock == NULL);
                        TEST_ASSERT(status == -1);
                        TEST_ASSERT(PR_GetError() == PR_IO_TIMEOUT_ERROR);
                        break;
                }
                break;
#endif
        }
        if (clientSock != NULL) {
            PR_Close(clientSock);
            clientSock = NULL;
        }
    }
    PR_Close(listenSock);
    listenSock = NULL;

    PR_JoinThread(clientThread);
    return;

CloseListener:
    /* Setup failed before the client thread was started: just release the listener. */
    PR_Close(listenSock);
    listenSock = NULL;
}


/*
 * Successful-connection scenarios: the client connects and sends
 * CLIENT_DATA bytes; each wrapper picks the server-side accept API.
 * The ACCEPT_*_FAST* flavours are only handled by RunTest() on WINNT.
 */

/* PR_Accept() followed by PR_Recv(). */
void AcceptUpdatedTest(void) { RunTest(ACCEPT_NORMAL, CLIENT_NORMAL); }

/* PR_NTFast_Accept() followed by PR_Recv(). */
void AcceptNotUpdatedTest(void) { RunTest(ACCEPT_FAST, CLIENT_NORMAL); }

/* PR_AcceptRead(). */
void AcceptReadTest(void) { RunTest(ACCEPT_READ, CLIENT_NORMAL); }

/* PR_NTFast_AcceptRead(). */
void AcceptReadNotUpdatedTest(void) { RunTest(ACCEPT_READ_FAST, CLIENT_NORMAL); }

/* PR_NTFast_AcceptRead_WithTimeoutCallback(). */
void AcceptReadCallbackTest(void) { RunTest(ACCEPT_READ_FAST_CB, CLIENT_NORMAL); }

/*
 * Accept-timeout scenarios: the client never connects, so the server's
 * accept call is expected to time out.
 */

/* PR_Accept() must fail with PR_IO_TIMEOUT_ERROR. */
void TimeoutAcceptUpdatedTest(void) { RunTest(ACCEPT_NORMAL, CLIENT_TIMEOUT_ACCEPT); }

/* PR_NTFast_Accept() must fail with PR_IO_TIMEOUT_ERROR. */
void TimeoutAcceptNotUpdatedTest(void) { RunTest(ACCEPT_FAST, CLIENT_TIMEOUT_ACCEPT); }

/* Combination RunTest() treats as an invalid test case (TEST_ASSERT(0)). */
void TimeoutAcceptReadCallbackTest(void) { RunTest(ACCEPT_READ_FAST_CB, CLIENT_TIMEOUT_ACCEPT); }

/*
 * Read-timeout scenarios: the client connects but never sends, so the
 * server's read (or the read half of an AcceptRead) is expected to time out.
 */

/* PR_Accept() then PR_Recv() must time out. */
void TimeoutReadUpdatedTest(void) { RunTest(ACCEPT_NORMAL, CLIENT_TIMEOUT_SEND); }

/* PR_NTFast_Accept() then PR_Recv() must time out. */
void TimeoutReadNotUpdatedTest(void) { RunTest(ACCEPT_FAST, CLIENT_TIMEOUT_SEND); }

/* PR_AcceptRead() must time out. */
void TimeoutReadReadTest(void) { RunTest(ACCEPT_READ, CLIENT_TIMEOUT_SEND); }

/* PR_NTFast_AcceptRead() must time out. */
void TimeoutReadReadNotUpdatedTest(void) { RunTest(ACCEPT_READ_FAST, CLIENT_TIMEOUT_SEND); }

/* PR_NTFast_AcceptRead_WithTimeoutCallback() must time out. */
void TimeoutReadReadCallbackTest(void) { RunTest(ACCEPT_READ_FAST_CB, CLIENT_TIMEOUT_SEND); }

/************************************************************************/

/*
 * Measure -- run one scenario and, in debug mode, report its wall-clock time.
 *
 * func  scenario to run (one of the *Test wrappers)
 * msg   label printed right-aligned in a 40-column field
 *
 * The reported figure is the total elapsed microseconds divided by `count`,
 * i.e. the average per iteration.
 */
static void Measure(void (*func)(void), const char *msg)
{
    const PRIntervalTime began = PR_IntervalNow();
    double elapsed_usec;

    func();
    elapsed_usec = (double)PR_IntervalToMicroseconds(PR_IntervalNow() - began);

    if (!debug_mode) {
        return;
    }
    PR_fprintf(output, "%40s: %6.2f usec\n", msg, elapsed_usec / count);
}

int main(int argc, char **argv)
{

    /* The command line argument: -d is used to determine if the test is being run
    in debug mode. The regress tool requires only one line output:PASS or FAIL.
    All of the printfs associated with this test has been handled with a if (debug_mode)
    test.
    Usage: test_name [-d] [-c n]
    */
    PLOptStatus os;
    PLOptState *opt = PL_CreateOptState(argc, argv, "Gdc:");
    while (PL_OPT_EOL != (os = PL_GetNextOpt(opt)))
    {
        if (PL_OPT_BAD == os) {
            continue;
        }
        switch (opt->option)
        {
            case 'G':  /* global threads */
                thread_scope = PR_GLOBAL_THREAD;
                break;
            case 'd':  /* debug mode */
                debug_mode = 1;
                break;
            case 'c':  /* loop counter */
                count = atoi(opt->value);
                break;
            default:
                break;
        }
    }
    PL_DestroyOptState(opt);

    PR_Init(PR_USER_THREAD, PR_PRIORITY_NORMAL, 0);
    output = PR_STDERR;
    PR_STDIO_INIT();

    timeoutTime = PR_SecondsToInterval(TIMEOUTSECS);
    if (debug_mode) {
        PR_fprintf(output, "\nRun accept() sucessful connection tests\n");
    }

    Measure(AcceptUpdatedTest, "PR_Accept()");
    Measure(AcceptReadTest, "PR_AcceptRead()");
#ifdef WINNT
    Measure(AcceptNotUpdatedTest, "PR_NTFast_Accept()");
    Measure(AcceptReadNotUpdatedTest, "PR_NTFast_AcceptRead()");
    Measure(AcceptReadCallbackTest, "PR_NTFast_AcceptRead_WithTimeoutCallback()");
#endif
    if (debug_mode) {
        PR_fprintf(output, "\nRun accept() timeout in the accept tests\n");
    }
#ifdef WINNT
    Measure(TimeoutReadReadCallbackTest, "PR_NTFast_AcceptRead_WithTimeoutCallback()");
#endif
    Measure(TimeoutReadUpdatedTest, "PR_Accept()");
    if (debug_mode) {
        PR_fprintf(output, "\nRun accept() timeout in the read tests\n");
    }
    Measure(TimeoutReadReadTest, "PR_AcceptRead()");
#ifdef WINNT
    Measure(TimeoutReadNotUpdatedTest, "PR_NTFast_Accept()");
    Measure(TimeoutReadReadNotUpdatedTest, "PR_NTFast_AcceptRead()");
    Measure(TimeoutReadReadCallbackTest, "PR_NTFast_AcceptRead_WithTimeoutCallback()");
#endif
    PR_fprintf(output, "%s\n", (failed_already) ? "FAIL" : "PASS");
    return failed_already;
}