/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "nsCOMPtr.h"
#include "gtest/gtest.h"

#include "mozilla/gtest/MozAssertions.h"
#include "mozilla/Unused.h"

// IID of the test-only IFoo interface; wired to IFoo through
// NS_DECLARE_STATIC_IID_ACCESSOR / NS_DEFINE_STATIC_IID_ACCESSOR.
#define NS_IFOO_IID                                      \
  {0x6f7652e0, 0xee43, 0x11d1,                           \
   {0x9c, 0xc3, 0x00, 0x60, 0x08, 0x8c, 0xa6, 0xb3}}

namespace TestCOMPtr {

// Minimal hand-written XPCOM-style interface used to observe how nsCOMPtr
// drives AddRef/Release/QueryInterface. Its class-wide counters let the tests
// check how many objects were built, destroyed and queried.
class IFoo : public nsISupports {
 public:
  NS_DECLARE_STATIC_IID_ACCESSOR(NS_IFOO_IID)

 public:
  IFoo();
  // Virtual because IBar objects are deleted by IFoo::Release(), i.e. through
  // an IFoo pointer.
  virtual ~IFoo();

  // nsISupports methods, implemented by hand in this file.
  NS_IMETHOD_(MozExternalRefCountType) AddRef() override;
  NS_IMETHOD_(MozExternalRefCountType) Release() override;
  NS_IMETHOD QueryInterface(const nsIID&, void**) override;

  // Class-wide instrumentation; tests reset these before measuring.
  static int total_constructions_;
  static int total_destructions_;
  static int total_queries_;

  // Number of outstanding references; read directly by the tests.
  unsigned int refcount_;
};

NS_DEFINE_STATIC_IID_ACCESSOR(IFoo, NS_IFOO_IID)

int IFoo::total_constructions_;
int IFoo::total_destructions_;
int IFoo::total_queries_;

IFoo::IFoo() : refcount_(0) { ++total_constructions_; }

IFoo::~IFoo() { ++total_destructions_; }

// Takes one more reference and reports the updated count.
MozExternalRefCountType IFoo::AddRef() { return ++refcount_; }

// Drops one reference. When the count reaches zero the object deletes itself
// (via the virtual destructor, so derived objects such as IBar are fully
// destroyed). Returns the remaining reference count.
MozExternalRefCountType IFoo::Release() {
  // Keep the count in the same unsigned type as refcount_ instead of
  // round-tripping it through a signed int. It is captured before |delete
  // this| so nothing is read from the object after it is freed.
  const unsigned int newcount = --refcount_;

  if (newcount == 0) {
    delete this;
  }

  return newcount;
}

// Hand-written QueryInterface. It answers for IFoo and nsISupports, AddRefs
// the returned pointer on success, and otherwise stores null and returns
// NS_ERROR_NO_INTERFACE. Every call bumps IFoo::total_queries_.
nsresult IFoo::QueryInterface(const nsIID& aIID, void** aResult) {
  ++total_queries_;

  nsISupports* found = nullptr;
  if (aIID.Equals(NS_GET_IID(IFoo))) {
    found = this;
  } else {
    const nsID supportsIID = NS_ISUPPORTS_IID;
    if (aIID.Equals(supportsIID)) {
      found = static_cast<nsISupports*>(this);
    }
  }

  NS_IF_ADDREF(found);
  *aResult = found;

  // |found| is non-null exactly when one of the IIDs above matched.
  return found ? NS_OK : NS_ERROR_NO_INTERFACE;
}

// A typical factory function: allocates an IFoo, AddRefs it on behalf of the
// caller, and hands it out through the untyped out-parameter. Returns NS_OK.
static nsresult CreateIFoo(void** result) {
  IFoo* const created = new IFoo();
  created->AddRef();
  *result = static_cast<void*>(created);
  return NS_OK;
}

// Stores a freshly created IFoo into |*result| through an out-parameter.
static void set_a_IFoo(nsCOMPtr<IFoo>* result) {
  // The upcast to nsISupports* is deliberate, and is repeated throughout this
  // file. It makes the QueryInterface non-trivial, which avoids hitting a
  // static assert.
  nsISupports* const raw = static_cast<nsISupports*>(new IFoo);
  nsCOMPtr<IFoo> created(do_QueryInterface(raw));
  *result = created;
}

// Returns a freshly created IFoo by value (see set_a_IFoo for the
// out-parameter variant).
static nsCOMPtr<IFoo> return_a_IFoo() {
  nsISupports* const raw = static_cast<nsISupports*>(new IFoo);
  return nsCOMPtr<IFoo>(do_QueryInterface(raw));
}

// IID of the test-only IBar interface. It differs from NS_IFOO_IID only in
// its first field.
#define NS_IBAR_IID                                      \
  {0x6f7652e1, 0xee43, 0x11d1,                           \
   {0x9c, 0xc3, 0x00, 0x60, 0x08, 0x8c, 0xa6, 0xb3}}

// Test interface derived from IFoo. It reuses IFoo's AddRef/Release but has
// its own IID, its own QueryInterface and its own counters. Those counters
// hide IFoo's identically named total_destructions_ / total_queries_.
class IBar : public IFoo {
 public:
  NS_DECLARE_STATIC_IID_ACCESSOR(NS_IBAR_IID)

 public:
  IBar();
  ~IBar() override;

  // Answers for IBar, IFoo and nsISupports.
  NS_IMETHOD QueryInterface(const nsIID&, void**) override;

  // Per-class instrumentation, separate from IFoo's counters.
  static int total_queries_;
  static int total_destructions_;
};

NS_DEFINE_STATIC_IID_ACCESSOR(IBar, NS_IBAR_IID)

// IBar's own counters (explicitly zero, as static storage would be anyway).
int IBar::total_destructions_ = 0;
int IBar::total_queries_ = 0;

// Construction is counted by the IFoo base-class constructor.
IBar::IBar() = default;

IBar::~IBar() { ++total_destructions_; }

// Hand-written QueryInterface for IBar. It answers for IBar, IFoo and
// nsISupports, AddRefs the returned pointer on success, and otherwise stores
// null and returns NS_ERROR_NO_INTERFACE. Every call bumps
// IBar::total_queries_ (never IFoo's counter).
nsresult IBar::QueryInterface(const nsID& aIID, void** aResult) {
  ++total_queries_;

  nsISupports* found = nullptr;
  if (aIID.Equals(NS_GET_IID(IBar))) {
    found = this;
  } else if (aIID.Equals(NS_GET_IID(IFoo))) {
    found = static_cast<IFoo*>(this);
  } else {
    const nsID supportsIID = NS_ISUPPORTS_IID;
    if (aIID.Equals(supportsIID)) {
      found = static_cast<nsISupports*>(this);
    }
  }

  NS_IF_ADDREF(found);
  *aResult = found;

  // |found| is non-null exactly when one of the IIDs above matched.
  return found ? NS_OK : NS_ERROR_NO_INTERFACE;
}

// A typical factory function: allocates an IBar, AddRefs it on behalf of the
// caller, and hands it out through the untyped out-parameter. Returns NS_OK.
static nsresult CreateIBar(void** result) {
  IBar* const created = new IBar();
  created->AddRef();
  *result = static_cast<void*>(created);
  return NS_OK;
}

// No-op sinks taking different pointer-to-pointer types. GetterConversions
// only checks that getter_AddRefs(...) converts to each of them.
static void AnIFooPtrPtrContext(IFoo** /* unused */) {}

static void AVoidPtrPtrContext(void** /* unused */) {}

static void AnISupportsPtrPtrContext(nsISupports** /* unused */) {}

}  // namespace TestCOMPtr

using namespace TestCOMPtr;

TEST(COMPtr, Bloat_Raw_Unsafe)
{
  // Raw-pointer XPCOM usage with no smart pointers:
  // 1. create an IBar via the factory,
  // 2. QueryInterface it to IFoo,
  // 3. release both references by hand.
  // (Original author's note: intent of this test was unclear to them.)
  IBar* rawBar = nullptr;
  nsresult rv = CreateIBar(reinterpret_cast<void**>(&rawBar));
  ASSERT_NS_SUCCEEDED(rv);
  ASSERT_TRUE(rawBar);

  IFoo* rawFoo = nullptr;
  rv = rawBar->QueryInterface(NS_GET_IID(IFoo),
                              reinterpret_cast<void**>(&rawFoo));
  ASSERT_NS_SUCCEEDED(rv);
  ASSERT_TRUE(rawFoo);

  NS_RELEASE(rawFoo);
  NS_RELEASE(rawBar);
}

TEST(COMPtr, Bloat_Smart)
{
  // Same create-then-QueryInterface flow as the raw-pointer variant, but the
  // references are owned by nsCOMPtr, so nothing is released by hand.
  // (Original author's note: intent of this test was unclear to them.)
  nsCOMPtr<IBar> bar;
  nsresult rv = CreateIBar(getter_AddRefs(bar));
  ASSERT_NS_SUCCEEDED(rv);
  ASSERT_TRUE(bar);

  nsISupports* const barAsSupports = static_cast<nsISupports*>(bar);
  nsCOMPtr<IFoo> foo(do_QueryInterface(barAsSupports, &rv));
  ASSERT_NS_SUCCEEDED(rv);
  ASSERT_TRUE(foo);
}

TEST(COMPtr, AddRefAndRelease)
{
  IFoo::total_constructions_ = 0;
  IFoo::total_destructions_ = 0;
  IBar::total_destructions_ = 0;

  {
    nsCOMPtr<IFoo> foop(do_QueryInterface(static_cast<nsISupports*>(new IFoo)));
    ASSERT_EQ(foop->refcount_, 1u);
    ASSERT_EQ(IFoo::total_constructions_, 1);
    ASSERT_EQ(IFoo::total_destructions_, 0);

    // Reassigning releases the previously held object, which had no other
    // owner and is therefore destroyed.
    foop = do_QueryInterface(static_cast<nsISupports*>(new IFoo));
    ASSERT_EQ(foop->refcount_, 1u);
    ASSERT_EQ(IFoo::total_constructions_, 2);
    ASSERT_EQ(IFoo::total_destructions_, 1);

    // [Shouldn't compile] Is it a compile time error to try to |AddRef| by
    // hand?
    // foop->AddRef();

    // [Shouldn't compile] Is it a compile time error to try to |Release| by
    // hand?
    // foop->Release();

    // [Shouldn't compile] Is it a compile time error to try to |delete| an
    // |nsCOMPtr|?
    // delete foop;

    // Going through the raw pointer does allow manual AddRef/Release; the
    // pair must balance out.
    static_cast<IFoo*>(foop)->AddRef();
    ASSERT_EQ(foop->refcount_, 2u);
    ASSERT_EQ(IFoo::total_constructions_, 2);
    ASSERT_EQ(IFoo::total_destructions_, 1);

    static_cast<IFoo*>(foop)->Release();
    ASSERT_EQ(foop->refcount_, 1u);
    ASSERT_EQ(IFoo::total_constructions_, 2);
    ASSERT_EQ(IFoo::total_destructions_, 1);
  }

  // Leaving the scope dropped the last reference to the second object.
  ASSERT_EQ(IFoo::total_constructions_, 2);
  ASSERT_EQ(IFoo::total_destructions_, 2);

  {
    // An IBar held through nsCOMPtr<IFoo> must still run ~IBar when released.
    nsCOMPtr<IFoo> foop(do_QueryInterface(static_cast<nsISupports*>(new IBar)));
    mozilla::Unused << foop;
  }

  ASSERT_EQ(IBar::total_destructions_, 1);
}

TEST(COMPtr, Comparison)
{
  IFoo::total_constructions_ = 0;
  IFoo::total_destructions_ = 0;

  {
    nsCOMPtr<IFoo> foo1p(
        do_QueryInterface(static_cast<nsISupports*>(new IFoo)));
    nsCOMPtr<IFoo> foo2p(
        do_QueryInterface(static_cast<nsISupports*>(new IFoo)));

    ASSERT_EQ(IFoo::total_constructions_, 2);

    // Test != operator (smart vs. smart, and smart vs. raw).
    ASSERT_NE(foo1p, foo2p);
    ASSERT_NE(foo1p, foo2p.get());

    // Point both at the same object. foo1p's previous object loses its only
    // reference and is destroyed.
    foo1p = foo2p;

    ASSERT_EQ(IFoo::total_destructions_, 1);

    // Test == operator (smart vs. smart, and both operand orders vs. raw).
    ASSERT_EQ(foo1p, foo2p);
    ASSERT_EQ(foo2p, foo2p.get());
    ASSERT_EQ(foo2p.get(), foo2p);

    // Test conversion to a truth value.
    ASSERT_TRUE(foo1p);

    // The shared object is now referenced by both smart pointers.
    ASSERT_EQ(foo1p->refcount_, 2u);
    ASSERT_EQ(foo2p->refcount_, 2u);
  }

  ASSERT_EQ(IFoo::total_destructions_, 2);
}

TEST(COMPtr, DontAddRef)
{
  {
    // dont_AddRef adopts a pointer that already carries a reference, so the
    // count must stay at 1 (via the constructor and via assignment).
    auto* raw_foo1p = new IFoo;
    raw_foo1p->AddRef();

    nsCOMPtr<IFoo> foo1p(dont_AddRef(raw_foo1p));
    ASSERT_EQ(raw_foo1p, foo1p);
    ASSERT_EQ(foo1p->refcount_, 1u);

    // Create the second object only now. A failing ASSERT above returns early
    // and would otherwise leak an AddRef'd object that no nsCOMPtr owns yet.
    auto* raw_foo2p = new IFoo;
    raw_foo2p->AddRef();

    nsCOMPtr<IFoo> foo2p;
    foo2p = dont_AddRef(raw_foo2p);
    ASSERT_EQ(raw_foo2p, foo2p);
    ASSERT_EQ(foo2p->refcount_, 1u);
  }
}

TEST(COMPtr, AssignmentHelpers)
{
  IFoo::total_constructions_ = 0;
  IFoo::total_destructions_ = 0;

  {
    // Explicit nsGetterAddRefs helper.
    nsCOMPtr<IFoo> foop;
    ASSERT_FALSE(foop);
    nsresult rv = CreateIFoo(nsGetterAddRefs<IFoo>(foop));
    ASSERT_NS_SUCCEEDED(rv);
    ASSERT_TRUE(foop);
  }

  ASSERT_EQ(IFoo::total_constructions_, 1);
  ASSERT_EQ(IFoo::total_destructions_, 1);

  {
    // getter_AddRefs convenience function.
    nsCOMPtr<IFoo> foop;
    ASSERT_FALSE(foop);
    nsresult rv = CreateIFoo(getter_AddRefs(foop));
    ASSERT_NS_SUCCEEDED(rv);
    ASSERT_TRUE(foop);
  }

  ASSERT_EQ(IFoo::total_constructions_, 2);
  ASSERT_EQ(IFoo::total_destructions_, 2);

  {
    // Assignment through address_of and from a by-value return.
    nsCOMPtr<IFoo> foop;
    ASSERT_FALSE(foop);
    set_a_IFoo(address_of(foop));
    ASSERT_TRUE(foop);

    ASSERT_EQ(IFoo::total_constructions_, 3);
    ASSERT_EQ(IFoo::total_destructions_, 2);

    foop = return_a_IFoo();
    ASSERT_TRUE(foop);

    ASSERT_EQ(IFoo::total_constructions_, 4);
    ASSERT_EQ(IFoo::total_destructions_, 3);
  }

  ASSERT_EQ(IFoo::total_constructions_, 4);
  ASSERT_EQ(IFoo::total_destructions_, 4);

  {
    // Move construction transfers ownership without creating or destroying
    // any object.
    nsCOMPtr<IFoo> fooP(do_QueryInterface(static_cast<nsISupports*>(new IFoo)));
    ASSERT_TRUE(fooP);

    ASSERT_EQ(IFoo::total_constructions_, 5);
    ASSERT_EQ(IFoo::total_destructions_, 4);

    nsCOMPtr<IFoo> fooP2(std::move(fooP));
    ASSERT_TRUE(fooP2);

    ASSERT_EQ(IFoo::total_constructions_, 5);
    ASSERT_EQ(IFoo::total_destructions_, 4);
  }

  ASSERT_EQ(IFoo::total_constructions_, 5);
  ASSERT_EQ(IFoo::total_destructions_, 5);
}

TEST(COMPtr, QueryInterface)
{
  IFoo::total_queries_ = 0;
  IBar::total_queries_ = 0;

  {
    nsCOMPtr<IFoo> fooP;
    ASSERT_FALSE(fooP);
    fooP = do_QueryInterface(static_cast<nsISupports*>(new IFoo));
    ASSERT_TRUE(fooP);
    ASSERT_EQ(IFoo::total_queries_, 1);

    nsCOMPtr<IFoo> foo2P;

    // Test that |QueryInterface| is _not_ called when assigning a
    // smart-pointer of the same type.
    foo2P = fooP;
    ASSERT_EQ(IFoo::total_queries_, 1);
  }

  {
    nsCOMPtr<IBar> barP(do_QueryInterface(static_cast<nsISupports*>(new IBar)));
    ASSERT_EQ(IBar::total_queries_, 1);

    // Test that |QueryInterface| is called when assigning a smart-pointer of
    // a different type. IBar's own QueryInterface answers, so IFoo's counter
    // must not move.
    nsCOMPtr<IFoo> fooP(do_QueryInterface(static_cast<nsISupports*>(barP)));
    ASSERT_EQ(IBar::total_queries_, 2);
    ASSERT_EQ(IFoo::total_queries_, 1);
    ASSERT_TRUE(fooP);
  }
}

TEST(COMPtr, GetterConversions)
{
  // Compilation test: getter_AddRefs(...) must convert to each of the
  // pointer-to-pointer parameter types below. The runtime asserts only keep
  // gtest happy.
  {
    nsCOMPtr<IFoo> foo;
    ASSERT_FALSE(foo);

    AnIFooPtrPtrContext(getter_AddRefs(foo));
    AVoidPtrPtrContext(getter_AddRefs(foo));
  }

  {
    nsCOMPtr<nsISupports> supports;
    ASSERT_FALSE(supports);

    AVoidPtrPtrContext(getter_AddRefs(supports));
    AnISupportsPtrPtrContext(getter_AddRefs(supports));
  }
}