# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import re
from copy import deepcopy
from collections import OrderedDict
import itertools

import ipdl.ast
import ipdl.builtin
from ipdl.cxx.ast import *
from ipdl.cxx.code import *
from ipdl.type import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes
from ipdl.util import hash_str


# -----------------------------------------------------------------------------
# "Public" interface to lowering
##


class LowerToCxx:
    def lower(self, tu, segmentcapacitydict):
        """Lower the translation unit |tu| to C++ source files.

        Returns |(headers, cpps)|, two lists of |File|s. The first entry of
        each is the protocol-level |<tu.name>.h| / |<tu.name>.cpp| pair. When
        |tu| defines a protocol, the Parent and then the Child actor pairs
        follow.
        """
        # Annotate the AST with the IPDL/C++ IR-type information that the
        # code generators below rely on.
        tu.accept(_DecorateWithCxxStuff())

        # NB: this file naming scheme must stay in sync with the ipdl.py
        # driver script.
        headers = [File(tu.name + ".h")]
        cpps = [File(tu.name + ".cpp")]
        _GenerateProtocolCode().lower(tu, headers[0], cpps[0], segmentcapacitydict)

        if tu.protocol:
            pname = tu.protocol.name
            sideGenerators = (
                ("Parent", _GenerateProtocolParentCode),
                ("Child", _GenerateProtocolChildCode),
            )
            for side, generatorClass in sideGenerators:
                actorname = pname + side
                header = File(actorname + ".h")
                cpp = File(actorname + ".cpp")
                generatorClass().lower(tu, actorname, header, cpp)
                headers.append(header)
                cpps.append(cpp)

        return headers, cpps


# -----------------------------------------------------------------------------
# Helper code
##


def hashfunc(value):
    """Return |hash_str(value)| reduced to an unsigned 32-bit value.

    Python's ``%`` with a positive right operand always yields a result in
    ``[0, 2**32)``, even for negative left operands. No separate sign fix-up
    is therefore needed.
    """
    return hash_str(value) % 2**32


# Actor-id literal constants: |ExprLiteral.ZERO| is named as the "null" actor
# id and |ExprLiteral.ONE| as the "freed" actor id.
_NULL_ACTOR_ID = ExprLiteral.ZERO
_FREED_ACTOR_ID = ExprLiteral.ONE

# "Automatically generated" banner wrapped as |Whitespace|. Presumably it is
# emitted at the top of generated files by code later in this module.
# NOTE(review): confirm the use site.
_DISCLAIMER = Whitespace(
    """//
// Automatically generated by ipdlc.
// Edit at your own risk
//

"""
)


class _struct:
    pass


def _namespacedHeaderName(name, namespaces):
    pfx = "/".join([ns.name for ns in namespaces])
    if pfx:
        return pfx + "/" + name
    else:
        return name


def _ipdlhHeaderName(tu):
    assert tu.filetype == "header"
    return _namespacedHeaderName(tu.name, tu.namespaces)


def _protocolHeaderName(p, side=""):
    if side:
        side = side.title()
    base = p.name + side
    return _namespacedHeaderName(base, p.namespaces)


def _includeGuardMacroName(headerfile):
    return re.sub(r"[./]", "_", headerfile.name)


def _includeGuardStart(headerfile):
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective("ifndef", guard), CppDirective("define", guard)]


def _includeGuardEnd(headerfile):
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective("endif", "// ifndef " + guard)]


def _messageStartName(ptype):
    return ptype.name() + "MsgStart"


def _protocolId(ptype):
    return ExprVar(_messageStartName(ptype))


def _protocolIdType():
    return Type.INT32


def _actorName(pname, side):
    """|pname| is the protocol name. |side| is 'Parent' or 'Child'."""
    tag = side
    if not tag[0].isupper():
        tag = side.title()
    return pname + tag


def _actorIdType():
    return Type.INT32


def _actorTypeTagType():
    return Type.INT32


def _actorId(actor=None):
    if actor is not None:
        return ExprCall(ExprSelect(actor, "->", "Id"))
    return ExprCall(ExprVar("Id"))


def _actorHId(actorhandle):
    return ExprSelect(actorhandle, ".", "mId")


def _backstagePass():
    """Return the expression |mozilla::ipc::PrivateIPDLInterface()|."""
    ctor = ExprVar("mozilla::ipc::PrivateIPDLInterface")
    return ExprCall(ctor)


def _deleteId():
    """Return an expression naming the |__delete__| message id."""
    idname = "Msg___delete____ID"
    return ExprVar(idname)


def _deleteReplyId():
    """Return an expression naming the |__delete__| reply id."""
    idname = "Reply___delete____ID"
    return ExprVar(idname)


def _lookupListener(idexpr):
    """Return the expression |Lookup(idexpr)|."""
    lookup = ExprVar("Lookup")
    return ExprCall(lookup, args=[idexpr])


def _nestInNamespaces(nsnames, cxxthing):
    """Return |cxxthing| wrapped in nested |Namespace|s named by |nsnames|.

    |nsnames| is ordered outermost first. If it is empty, |cxxthing| itself
    is returned; otherwise the outermost |Namespace| is returned.
    """
    if len(nsnames) == 0:
        return cxxthing

    outermost = Namespace(nsnames[0])
    current = outermost
    for nsname in nsnames[1:]:
        nested = Namespace(nsname)
        current.addstmt(nested)
        current = nested
    current.addstmt(cxxthing)
    return outermost


def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False):
    """Forward-declare |clsname| (class and/or struct per |cls|/|struct|)
    inside the namespaces named by |quals| (outermost first)."""
    return _nestInNamespaces(quals, ForwardDecl(clsname, cls=cls, struct=struct))


def _makeForwardDeclForActor(ptype, side):
    """Forward-declare the |side| actor class of protocol type |ptype|."""
    qname = ptype.qname
    return _makeForwardDeclForQClass(_actorName(qname.baseid, side), qname.quals)


def _makeForwardDecl(type):
    """Forward-declare the class for |type| inside its qualifying namespaces."""
    return _makeForwardDeclForQClass(type.name(), type.qname.quals)


def _putInNamespaces(cxxthing, namespaces):
    """Wrap |cxxthing| in |namespaces|, ordered [ outer, ..., inner ]."""
    return _nestInNamespaces([ns.name for ns in namespaces], cxxthing)


def _sendPrefix(msgtype):
    """Prefix of the name of the C++ method that sends |msgtype|."""
    if msgtype.isInterrupt():
        return "Call"
    return "Send"


def _recvPrefix(msgtype):
    """Prefix of the name of the C++ method that handles |msgtype|."""
    if msgtype.isInterrupt():
        return "Answer"
    return "Recv"


def _flatTypeName(ipdltype):
    """Return a 'flattened' IPDL type name that can be used as an
    identifier.
    E.g., |Foo[]| --> |ArrayOfFoo|."""
    # NB: this logic depends heavily on what IPDL types are allowed to
    # be constructed; e.g., Foo[][] is disallowed.  needs to be kept in
    # sync with grammar.
    if ipdltype.isIPDL() and ipdltype.isArray():
        return "ArrayOf" + _flatTypeName(ipdltype.basetype)
    if ipdltype.isIPDL() and ipdltype.isMaybe():
        return "Maybe" + _flatTypeName(ipdltype.basetype)
    # NotNull types just assume the underlying variant name to avoid unnecessary
    # noise, as a NotNull<T> and T should never exist in the same union.
    if ipdltype.isIPDL() and ipdltype.isNotNull():
        return _flatTypeName(ipdltype.basetype)
    return ipdltype.name()


def _hasVisibleActor(ipdltype):
    """Return true iff a C++ decl of |ipdltype| would have an Actor* type.
    For example: |Actor[]| would turn into |Array<ActorParent*>|, so this
    function would return true for |Actor[]|."""
    return ipdltype.isIPDL() and (
        ipdltype.isActor()
        or (ipdltype.hasBaseType() and _hasVisibleActor(ipdltype.basetype))
    )


def _abortIfFalse(cond, msg):
    """Return the statement |MOZ_RELEASE_ASSERT(cond, "msg");|."""
    assertion = ExprVar("MOZ_RELEASE_ASSERT")
    return StmtExpr(ExprCall(assertion, [cond, ExprLiteral.String(msg)]))


def _refptr(T):
    """Return the C++ type |RefPtr<T>|."""
    return Type("RefPtr", T=T)


def _alreadyaddrefed(T):
    """Return the C++ type |already_AddRefed<T>|."""
    return Type("already_AddRefed", T=T)


def _tuple(types, const=False, ref=False):
    """Return the C++ type |std::tuple<types...>| with optional const/ref."""
    return Type("std::tuple", T=types, const=const, ref=ref)


def _promise(resolvetype, rejecttype, tail, resolver=False):
    """Return |MozPromise<resolvetype, rejecttype, tail>|.

    When |resolver| is set, the type is marked with the inner type |Private|.
    """
    if resolver:
        inner = Type("Private")
    else:
        inner = None
    return Type("MozPromise", T=[resolvetype, rejecttype, tail], inner=inner)


def _makePromise(returns, side, resolver=False):
    """Return the MozPromise type that resolves with the values of |returns|.

    Several returns resolve as a |std::tuple| of their bare types; a single
    return resolves as its bare type.
    """
    resolvetypes = [d.bareType(side) for d in returns]
    if len(resolvetypes) > 1:
        resolvetype = _tuple(resolvetypes)
    else:
        resolvetype = resolvetypes[0]

    # MozPromise is purposefully made to be exclusive only. Really, we mean it.
    return _promise(
        resolvetype, _ResponseRejectReason.Type(), ExprLiteral.TRUE, resolver=resolver
    )


def _resolveType(returns, side):
    """Return the type passed to a resolver: a tuple of the send-direction
    in-types of |returns|, or the single in-type when there is one return."""
    intypes = [d.inType(side, "send") for d in returns]
    return _tuple(intypes) if len(intypes) > 1 else intypes[0]


def _makeResolver(returns, side):
    """Return a function type taking one (unnamed) |_resolveType| argument."""
    param = Decl(_resolveType(returns, side), "")
    return TypeFunction([param])


def _cxxTemplateType(name, basetype, const, ref, copyable):
    """Return |name<basetype>| with the given const/ref flags and the given
    |hasimplicitcopyctor| value."""
    return Type(name, T=basetype, const=const, ref=ref, hasimplicitcopyctor=copyable)


def _cxxArrayType(basetype, const=False, ref=False):
    """Return |nsTArray<basetype>|, marked as having no implicit copy ctor."""
    return _cxxTemplateType("nsTArray", basetype, const, ref, False)


def _cxxSpanType(basetype, const=False, ref=False):
    """Return |mozilla::Span<basetype const>|.

    The element type is deep-copied before being marked right-const, so the
    caller's |basetype| is left untouched.
    """
    element = deepcopy(basetype)
    element.rightconst = True
    return _cxxTemplateType("mozilla::Span", element, const, ref, True)


def _cxxMaybeType(basetype, const=False, ref=False):
    """Return |mozilla::Maybe<basetype>|, copyable iff |basetype| is."""
    return _cxxTemplateType(
        "mozilla::Maybe", basetype, const, ref, basetype.hasimplicitcopyctor
    )


def _cxxReadResultType(basetype, const=False, ref=False):
    """Return |IPC::ReadResult<basetype>|, copyable iff |basetype| is."""
    return _cxxTemplateType(
        "IPC::ReadResult", basetype, const, ref, basetype.hasimplicitcopyctor
    )


def _cxxNotNullType(basetype, const=False, ref=False):
    """Return |mozilla::NotNull<basetype>|, copyable iff |basetype| is."""
    return _cxxTemplateType(
        "mozilla::NotNull", basetype, const, ref, basetype.hasimplicitcopyctor
    )


def _cxxManagedContainerType(basetype, const=False, ref=False):
    """Return |ManagedContainer<basetype>|, marked as having no implicit copy ctor."""
    return _cxxTemplateType("ManagedContainer", basetype, const, ref, False)


def _cxxLifecycleProxyType(ptr=False):
    """Return |mozilla::ipc::ActorLifecycleProxy|, optionally as a pointer."""
    proxyname = "mozilla::ipc::ActorLifecycleProxy"
    return Type(proxyname, ptr=ptr)


def _otherSide(side):
    if side == "child":
        return "parent"
    if side == "parent":
        return "child"
    assert 0


def _ifLogging(topLevelProtocol, stmts):
    """Wrap |stmts| in |if (mozilla::ipc::LoggingEnabledFor(proto)) { ... }|."""
    code = """
        if (mozilla::ipc::LoggingEnabledFor(${proto})) {
            $*{stmts}
        }
        """
    return StmtCode(code, proto=topLevelProtocol, stmts=stmts)


# XXX we need to remove these and install proper error handling


def _asStringLiteral(msg):
    """Wrap a Python str in |ExprLiteral.String|; return anything else as-is."""
    return ExprLiteral.String(msg) if isinstance(msg, str) else msg


def _printErrorMessage(msg):
    """Return the statement |NS_ERROR(msg);|. A str |msg| becomes a literal."""
    literal = _asStringLiteral(msg)
    return StmtExpr(ExprCall(ExprVar("NS_ERROR"), args=[literal]))


def _protocolErrorBreakpoint(msg):
    """Return |mozilla::ipc::ProtocolErrorBreakpoint(msg);|. A str |msg|
    becomes a literal."""
    literal = _asStringLiteral(msg)
    return StmtExpr(
        ExprCall(ExprVar("mozilla::ipc::ProtocolErrorBreakpoint"), args=[literal])
    )


def _printWarningMessage(msg):
    """Return the statement |NS_WARNING(msg);|. A str |msg| becomes a literal."""
    literal = _asStringLiteral(msg)
    return StmtExpr(ExprCall(ExprVar("NS_WARNING"), args=[literal]))


def _fatalError(msg):
    """Return the statement |FatalError("msg");|."""
    callee = ExprVar("FatalError")
    return StmtExpr(ExprCall(callee, args=[ExprLiteral.String(msg)]))


def _logicError(msg):
    """Return the statement |mozilla::ipc::LogicError("msg");|."""
    callee = ExprVar("mozilla::ipc::LogicError")
    return StmtExpr(ExprCall(callee, args=[ExprLiteral.String(msg)]))


def _sentinelReadError(classname):
    """Return the statement |mozilla::ipc::SentinelReadError("classname");|."""
    callee = ExprVar("mozilla::ipc::SentinelReadError")
    return StmtExpr(ExprCall(callee, args=[ExprLiteral.String(classname)]))


# Results that IPDL-generated code returns back to *Channel code.
# Users never see these


class _Result:
    """Expressions naming the C++ message-dispatch result values.

    |Type()| returns the C++ |Result| type; each attribute is an |ExprVar|
    naming one of its values.
    """

    @staticmethod
    def Type():
        resulttype = Type("Result")
        return resulttype

    Processed = ExprVar("MsgProcessed")
    NotKnown = ExprVar("MsgNotKnown")
    NotAllowed = ExprVar("MsgNotAllowed")
    PayloadError = ExprVar("MsgPayloadError")
    ProcessingError = ExprVar("MsgProcessingError")
    RouteError = ExprVar("MsgRouteError")
    # [sic] The attribute name is misspelled; kept because callers
    # (e.g. errfnRecv's default) refer to it by this name.
    ValuError = ExprVar("MsgValueError")


# These |errfn*| functions generate code to be executed on an
# error, such as "bad actor ID".  Each is given a Python string
# containing a description of the error.

# used in user-facing Send*() methods


def errfnSend(msg, errcode=ExprLiteral.FALSE):
    """Error code for user-facing Send*() methods.

    Reports |msg| via |FatalError|, then returns |errcode| (false by default).
    """
    fatal = _fatalError(msg)
    return [fatal, StmtReturn(errcode)]


def errfnSendCtor(msg):
    """Like |errfnSend|, but returns a null literal instead of false."""
    return errfnSend(msg, errcode=ExprLiteral.NULL)


# TODO should this error handling be strengthened for dtors?


def errfnSendDtor(msg):
    """Error code for destructor sends: |NS_ERROR(msg)| followed by |return false|."""
    logged = _printErrorMessage(msg)
    return [logged, StmtReturn.FALSE]


# used in |OnMessage*()| handlers that hand in-messages off to Recv*()
# interface methods


def errfnRecv(msg, errcode=_Result.ValuError):
    """Error code for message handlers.

    Reports |msg| via |FatalError|, then returns |errcode| (MsgValueError by
    default).
    """
    fatal = _fatalError(msg)
    return [fatal, StmtReturn(errcode)]


def errfnSentinel(rvalue=ExprLiteral.FALSE):
    """Return an error-code factory for sentinel read failures.

    The factory takes a class name, reports a sentinel read error for it,
    and then returns |rvalue|.
    """

    def onSentinelError(msg):
        readError = _sentinelReadError(msg)
        return [readError, StmtReturn(rvalue)]

    return onSentinelError


def _destroyMethod():
    """Return an expression naming |ActorDestroy|."""
    methodname = "ActorDestroy"
    return ExprVar(methodname)


def errfnUnreachable(msg):
    """Error code for paths that should not be reached: a |LogicError(msg)|."""
    logic = _logicError(msg)
    return [logic]


def readResultError():
    """Return the brace-initializer expression |{}|."""
    braces = "{}"
    return ExprCode(braces)


class _DestroyReason:
    """Expressions naming values of the C++ |ActorDestroyReason| type.

    |Type()| returns the C++ type itself.
    """

    @staticmethod
    def Type():
        reasontype = Type("ActorDestroyReason")
        return reasontype

    Deletion = ExprVar("Deletion")
    AncestorDeletion = ExprVar("AncestorDeletion")
    NormalShutdown = ExprVar("NormalShutdown")
    AbnormalShutdown = ExprVar("AbnormalShutdown")
    FailedConstructor = ExprVar("FailedConstructor")
    ManagedEndpointDropped = ExprVar("ManagedEndpointDropped")


class _ResponseRejectReason:
    """Expressions naming values of the C++ |ResponseRejectReason| type.

    The values are written fully qualified. |Type()| returns the C++ type
    itself.
    """

    @staticmethod
    def Type():
        reasontype = Type("ResponseRejectReason")
        return reasontype

    SendError = ExprVar("ResponseRejectReason::SendError")
    ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed")
    HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected")
    ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed")


# -----------------------------------------------------------------------------
# Intermediate representation (IR) nodes used during lowering


class _ConvertToCxxType(TypeVisitor):
    """Type visitor that lowers an IPDL type to a bare C++ |Type|.

    |side| names the actor side used for actor types. When it is None, actor
    types become a |SideVariant| over the 'parent' and 'child' types. |fq|
    selects fully-qualified (|fullname()|) rather than short (|name()|) names.
    """

    def __init__(self, side, fq):
        self.side = side
        self.fq = fq

    def typename(self, thing):
        """Return |thing|'s full name when |self.fq| is set, else its short name."""
        return thing.fullname() if self.fq else thing.name()

    def visitImportedCxxType(self, t):
        # Refcounted imported types are held through RefPtr<T>.
        cxxtype = Type(self.typename(t))
        return _refptr(cxxtype) if t.isRefcounted() else cxxtype

    def visitBuiltinCType(self, b):
        return Type(self.typename(b))

    def visitActorType(self, a):
        if self.side is not None:
            actorclass = _actorName(self.typename(a.protocol), self.side)
            return Type(actorclass, ptr=True)
        # Side-agnostic: a variant over both sides' actor types.
        sidetypes = [_cxxBareType(a, s, self.fq) for s in ("parent", "child")]
        return Type("::mozilla::ipc::SideVariant", T=sidetypes)

    def visitStructType(self, s):
        return Type(self.typename(s))

    def visitUnionType(self, u):
        return Type(self.typename(u))

    def visitArrayType(self, a):
        return _cxxArrayType(a.basetype.accept(self))

    def visitMaybeType(self, m):
        return _cxxMaybeType(m.basetype.accept(self))

    # The following IPDL builtins map to a C++ type of the same name.

    def visitShmemType(self, s):
        return Type(self.typename(s))

    def visitByteBufType(self, s):
        return Type(self.typename(s))

    def visitFDType(self, s):
        return Type(self.typename(s))

    def visitEndpointType(self, s):
        return Type(self.typename(s))

    def visitManagedEndpointType(self, s):
        return Type(self.typename(s))

    def visitUniquePtrType(self, s):
        return Type(self.typename(s))

    def visitNotNullType(self, n):
        return _cxxNotNullType(n.basetype.accept(self))

    # These kinds have no C++ value-type lowering here; reaching them trips
    # the assert.

    def visitProtocolType(self, p):
        assert 0

    def visitMessageType(self, m):
        assert 0

    def visitVoidType(self, v):
        assert 0


def _cxxBareType(ipdltype, side, fq=False):
    """Lower |ipdltype| to its C++ |Type| for |side|, with no const/ref
    qualifiers added. |fq| requests fully-qualified type names."""
    converter = _ConvertToCxxType(side, fq)
    return ipdltype.accept(converter)


def _cxxRefType(ipdltype, side):
    """Return the bare C++ type of |ipdltype| marked as a reference."""
    reftype = _cxxBareType(ipdltype, side)
    reftype.ref = True
    return reftype


def _cxxConstRefType(ipdltype, side):
    """Return the C++ type used to pass |ipdltype| by 'const reference'.

    The result is |const T&| unless one of these cases applies first (they
    are checked in order):
    - IPDL actor: the bare actor type, unchanged.
    - Shmem: a non-const reference |T&|.
    - NotNull whose inner const-ref type is a pointer: |NotNull<inner>|. A
      non-pointer inner falls through to the next case.
    - Any other IPDL type with a base type: a reference that is const unless
      the inner const-ref type is a non-const reference.
    - Send- or data-move-only C++ type: |const T&|. This returns before the
      refcounted check, so such types are never lowered to a raw pointer.
    - Refcounted C++ type: a raw |T*| instead of |const RefPtr<T>&|.
    """
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        return t
    if ipdltype.isIPDL() and ipdltype.isShmem():
        t.ref = True
        return t
    if ipdltype.isIPDL() and ipdltype.isNotNull():
        # If the inner type chooses to use a raw pointer, wrap that instead.
        inner = _cxxConstRefType(ipdltype.basetype, side)
        if inner.ptr:
            t = _cxxNotNullType(inner)
            return t
    if ipdltype.isIPDL() and ipdltype.hasBaseType():
        # Keep same constness as inner type.
        # A by-value inner (not a reference) also makes the wrapper const.
        inner = _cxxConstRefType(ipdltype.basetype, side)
        t.const = inner.const or not inner.ref
        t.ref = True
        return t
    if ipdltype.isCxx() and (ipdltype.isSendMoveOnly() or ipdltype.isDataMoveOnly()):
        t.const = True
        t.ref = True
        return t
    if ipdltype.isCxx() and ipdltype.isRefcounted():
        # Use T* instead of const RefPtr<T>&
        # (|t| is RefPtr<T> here; |t.T| is the wrapped T.)
        t = t.T
        t.ptr = True
        return t
    t.const = True
    t.ref = True
    return t


def _cxxTypeNeedsMoveForSend(ipdltype, context="root", visited=None):
    """Returns `True` if serializing ipdltype requires a mutable reference, e.g.
    because the underlying resource represented by the value is being
    transferred to another process. This is occasionally distinct from whether
    the C++ type exposes a copy constructor, such as for types which are not
    cheaply copiable, but are not mutated when serialized.

    |context| records how |ipdltype| was reached:
    - "root" for the top-level call;
    - "wrapper" when it is the base type of an IPDL type that has one (e.g.
      an array or maybe);
    - "compound" when it is a component of a struct or union.

    |visited| accumulates the types examined so far and is shared (and
    mutated) across the whole traversal. Struct and union components that
    are already in it are skipped, which stops the recursion on recursive
    type definitions.
    """

    if visited is None:
        visited = set()

    visited.add(ipdltype)

    # C++ types carry their own send-move-only flag.
    if ipdltype.isCxx():
        return ipdltype.isSendMoveOnly()

    if ipdltype.isIPDL():
        # Wrapper types defer to their base type. This path does not consult
        # |visited|; only struct/union components below do.
        if ipdltype.hasBaseType():
            return _cxxTypeNeedsMoveForSend(ipdltype.basetype, "wrapper", visited)
        # A struct/union needs a move if any not-yet-visited component does.
        # |any| short-circuits, and |visited| grows while the generator is
        # being consumed, so the |not in visited| filter is evaluated lazily
        # in |itercomponents()| order.
        if ipdltype.isStruct() or ipdltype.isUnion():
            return any(
                _cxxTypeNeedsMoveForSend(t, "compound", visited)
                for t in ipdltype.itercomponents()
                if t not in visited
            )

        # For historical reasons, shmem is `const_cast` to a mutable reference
        # when being stored in a struct or union (see
        # `_StructField.constRefExpr` and `_UnionMember.getConstValue`), meaning
        # that they do not cause the containing struct to require move for
        # sending.
        if ipdltype.isShmem():
            return context != "compound"

        # Remaining IPDL builtins that must be moved when sent.
        return (
            ipdltype.isByteBuf()
            or ipdltype.isEndpoint()
            or ipdltype.isManagedEndpoint()
        )

    # Neither a C++ nor an IPDL type: never requires a move.
    return False


def _cxxTypeNeedsMoveForData(ipdltype, context="root", visited=None):
    """Returns `True` if the bare C++ type corresponding to ipdltype does not
    satisfy std::is_copy_constructible_v<T>. All C++ types supported by IPDL
    must support std::is_move_constructible_v<T>, so non-movable types must be
    passed behind a `UniquePtr`.

    |context| records how |ipdltype| was reached:
    - "root" for the top-level call;
    - "wrapper" when it is the base type of an IPDL type that has one;
    - "compound" when it is a struct or union component.

    |visited| accumulates the types examined so far and is shared (and
    mutated) across the traversal. Struct and union components already in it
    are skipped, which stops recursion on recursive type definitions.
    """

    if visited is None:
        visited = set()

    visited.add(ipdltype)

    # C++ types carry their own data-move-only flag.
    if ipdltype.isCxx():
        return ipdltype.isDataMoveOnly()

    if ipdltype.isIPDL():
        # UniquePtr is always move-only.
        if ipdltype.isUniquePtr():
            return True

        # When nested within a maybe or array, arrays are no longer copyable.
        if context == "wrapper" and ipdltype.isArray():
            return True
        # Other wrapper types defer to their base type (without consulting
        # |visited|).
        if ipdltype.hasBaseType():
            return _cxxTypeNeedsMoveForData(ipdltype.basetype, "wrapper", visited)
        # A struct/union is move-only if any not-yet-visited component is.
        # |any| short-circuits and |visited| grows while the generator runs.
        if ipdltype.isStruct() or ipdltype.isUnion():
            return any(
                _cxxTypeNeedsMoveForData(t, "compound", visited)
                for t in ipdltype.itercomponents()
                if t not in visited
            )
        # Remaining IPDL builtins that are move-only.
        return (
            ipdltype.isByteBuf()
            or ipdltype.isEndpoint()
            or ipdltype.isManagedEndpoint()
        )

    # Neither a C++ nor an IPDL type: treated as copyable.
    return False


def _cxxTypeCanMove(ipdltype):
    return not (ipdltype.isIPDL() and ipdltype.isActor())


def _cxxForceMoveRefType(ipdltype, side):
    assert _cxxTypeCanMove(ipdltype)
    t = _cxxBareType(ipdltype, side)
    t.rvalref = True
    return t


def _cxxPtrToType(ipdltype, side):
    """Return a pointer to the bare C++ type of |ipdltype|.

    An actor with a known side becomes |Actor**|; everything else |T*|.
    """
    cxxtype = _cxxBareType(ipdltype, side)
    sidedActor = ipdltype.isIPDL() and ipdltype.isActor() and side is not None
    if sidedActor:
        cxxtype.ptr = False
        cxxtype.ptrptr = True
    else:
        cxxtype.ptr = True
    return cxxtype


def _cxxConstPtrToType(ipdltype, side):
    """Return a const pointer to the bare C++ type of |ipdltype|.

    An actor with a known side becomes |Actor* const*|; everything else
    |const T*|.
    """
    cxxtype = _cxxBareType(ipdltype, side)
    sidedActor = ipdltype.isIPDL() and ipdltype.isActor() and side is not None
    if sidedActor:
        cxxtype.ptr = False
        cxxtype.ptrconstptr = True
    else:
        cxxtype.const = True
        cxxtype.ptr = True
    return cxxtype


def _cxxInType(ipdltype, side, direction):
    """Return the C++ parameter type used to pass |ipdltype| as an in-param.

    Only the |direction| value "send" is special-cased (for arrays, below).
    The cases are checked in order:
    - IPDL actor: the bare type, unchanged.
    - NotNull whose inner in-type is a pointer: |NotNull<inner>|. Otherwise
      the bare NotNull type falls through.
    - Types that need a move for sending (|_cxxTypeNeedsMoveForSend|): |T&&|.
    - Refcounted C++ types: a raw |T*|.
    - nsCString / nsString: replaced by nsACString / nsAString, then passed as
      |const T&| by the fallthrough at the end.
    - Arrays when |direction| == "send": |Span<T const>| over the bare
      element type.
    - Everything else: |const T&|.
    """
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        return t
    if ipdltype.isIPDL() and ipdltype.isNotNull():
        # If the inner type chooses to use a raw pointer, wrap that instead.
        inner = _cxxInType(ipdltype.basetype, side, direction)
        if inner.ptr:
            t = _cxxNotNullType(inner)
            return t
    if _cxxTypeNeedsMoveForSend(ipdltype):
        t.rvalref = True
        return t
    if ipdltype.isCxx():
        if ipdltype.isRefcounted():
            # Use T* instead of const RefPtr<T>&
            t = t.T
            t.ptr = True
            return t
        # Accept the abstract string base types so any string flavor binds.
        if ipdltype.name() == "nsCString":
            t = Type("nsACString")
        if ipdltype.name() == "nsString":
            t = Type("nsAString")
    # Use Span<T const> rather than nsTArray<T> for array types which aren't
    # `_cxxTypeNeedsMoveForSend`. This is only done for the "send" side, and not
    # for recv signatures.
    if direction == "send" and ipdltype.isIPDL() and ipdltype.isArray():
        inner = _cxxBareType(ipdltype.basetype, side)
        return _cxxSpanType(inner)

    t.const = True
    t.ref = True
    return t


def _allocMethod(ptype, side):
    return "Alloc" + ptype.name() + side.title()


def _deallocMethod(ptype, side):
    return "Dealloc" + ptype.name() + side.title()


##
# A _HybridDecl straddles IPDL and C++ decls.  It knows which C++
# types correspond to which IPDL types, and it also knows how to
# serialize and deserialize "special" IPDL C++ types.
##


class _HybridDecl:
    """A hybrid decl stores both an IPDL type and all the C++ type
    info needed by later passes, along with a basic name for the decl."""

    def __init__(self, ipdltype, name, attributes={}):
        self.ipdltype = ipdltype
        self.name = name
        self.attributes = attributes

    def var(self):
        return ExprVar(self.name)

    def bareType(self, side, fq=False):
        """Return this decl's unqualified C++ type."""
        return _cxxBareType(self.ipdltype, side, fq=fq)

    def refType(self, side):
        """Return this decl's C++ type as a 'reference' type, which is not
        necessarily a C++ reference."""
        return _cxxRefType(self.ipdltype, side)

    def constRefType(self, side):
        """Return this decl's C++ type as a const, 'reference' type."""
        return _cxxConstRefType(self.ipdltype, side)

    def ptrToType(self, side):
        return _cxxPtrToType(self.ipdltype, side)

    def constPtrToType(self, side):
        return _cxxConstPtrToType(self.ipdltype, side)

    def inType(self, side, direction):
        """Return this decl's C++ Type with sending inparam semantics."""
        return _cxxInType(self.ipdltype, side, direction)

    def outType(self, side):
        """Return this decl's C++ Type with outparam semantics."""
        t = self.bareType(side)
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            t.ptr = False
            t.ptrptr = True
            return t
        t.ptr = True
        return t

    def forceMoveType(self, side):
        """Return this decl's C++ Type with forced move semantics."""
        assert _cxxTypeCanMove(self.ipdltype)
        return _cxxForceMoveRefType(self.ipdltype, side)


# --------------------------------------------------


class HasFQName:
    """Mixin for decls that expose |self.decl.type|."""

    def fqClassName(self):
        """Return the fully-qualified name of |self.decl.type|."""
        decltype = self.decl.type
        return decltype.fullname()


class _CompoundTypeComponent(_HybridDecl):
    """A |_HybridDecl| for a component of a compound (struct/union) type.

    Every type accessor defaults |side| to None.
    """

    # @override the following methods to make the side argument optional.
    def bareType(self, side=None, fq=False):
        return _HybridDecl.bareType(self, side=side, fq=fq)

    def refType(self, side=None):
        return _HybridDecl.refType(self, side=side)

    def constRefType(self, side=None):
        return _HybridDecl.constRefType(self, side=side)

    def ptrToType(self, side=None):
        return _HybridDecl.ptrToType(self, side=side)

    def constPtrToType(self, side=None):
        return _HybridDecl.constPtrToType(self, side=side)

    def forceMoveType(self, side=None):
        return _HybridDecl.forceMoveType(self, side=side)


class StructDecl(ipdl.ast.StructDecl, HasFQName):
    """|ipdl.ast.StructDecl| extended with lowering helpers.

    Existing AST nodes gain these helpers through |upgrade|.
    """

    def fields_ipdl_order(self):
        """Yield the fields in IPDL declaration order."""
        yield from self.fields

    def fields_member_order(self):
        """Yield the fields in the order given by |self.packed_field_order|,
        a list of indices into |self.fields|."""
        assert len(self.packed_field_order) == len(self.fields)
        yield from (self.fields[i] for i in self.packed_field_order)

    @staticmethod
    def upgrade(structDecl):
        """Re-class an |ipdl.ast.StructDecl| instance in place as a StructDecl."""
        assert isinstance(structDecl, ipdl.ast.StructDecl)
        structDecl.__class__ = StructDecl


class _StructField(_CompoundTypeComponent):
    """A field of an IPDL struct.

    The C++ member is |<name>_|, and the constructor argument is |_<name>|.
    """

    def __init__(self, ipdltype, name, sd):
        # |sd| (presumably the owning struct decl) is accepted but unused.
        self.basename = name

        _CompoundTypeComponent.__init__(self, ipdltype, name)

    def getMethod(self, thisexpr=None, sel="."):
        """Return the getter, selected off |thisexpr| via |sel| when given."""
        getter = self.var()
        if thisexpr is None:
            return getter
        return ExprSelect(thisexpr, sel, getter.name)

    def refExpr(self, thisexpr=None):
        """Return the member variable, selected off |thisexpr| when given."""
        member = self.memberVar()
        if thisexpr is None:
            return member
        return ExprSelect(thisexpr, ".", member.name)

    def constRefExpr(self, thisexpr=None):
        """Like |refExpr|, but Shmem members are const_cast to |Shmem&|."""
        # sigh, gross hack
        member = self.refExpr(thisexpr)
        isShmem = "Shmem" == self.ipdltype.name()
        if not isShmem:
            return member
        return ExprCast(member, Type("Shmem", ref=True), const=True)

    def argVar(self):
        """Return the constructor-argument variable |_<name>|."""
        argname = "_" + self.name
        return ExprVar(argname)

    def memberVar(self):
        """Return the member variable |<name>_|."""
        membername = self.name + "_"
        return ExprVar(membername)


class UnionDecl(ipdl.ast.UnionDecl, HasFQName):
    """|ipdl.ast.UnionDecl| extended with lowering helpers.

    Existing AST nodes gain these helpers through |upgrade|.
    """

    def callType(self, var=None):
        """Return |type()|, or |var.type()| when |var| is given."""
        callee = ExprVar("type")
        if var is None:
            return ExprCall(callee)
        return ExprCall(ExprSelect(var, ".", callee.name))

    @staticmethod
    def upgrade(unionDecl):
        """Re-class an |ipdl.ast.UnionDecl| instance in place as a UnionDecl."""
        assert isinstance(unionDecl, ipdl.ast.UnionDecl)
        unionDecl.__class__ = UnionDecl


class _UnionMember(_CompoundTypeComponent):
    """Not in the AFL sense, but rather a member (e.g. |int;|) of an
    IPDL union type.

    Names are derived from the member's flattened type name <flat> (see
    |_flatTypeName|):
    - storage field: |V<flat>|;
    - type tag: |T<flat>|;
    - typedef: |<flat>__tdef|;
    - accessors: |get_<flat>|, |ptr_<flat>| and |constptr_<flat>|.

    Members whose type is mutually recursive with the enclosing union are
    stored behind a pointer (|self.recursive|). All others are stored inline
    in |mozilla::AlignedStorage2|.
    """

    def __init__(self, ipdltype, ud):
        flatname = _flatTypeName(ipdltype)

        _CompoundTypeComponent.__init__(self, ipdltype, "V" + flatname)
        self.flattypename = flatname

        # To create a finite object with a mutually recursive type, a union must
        # be present somewhere in the recursive loop. Because of that we only
        # need to care about introducing indirections inside unions.
        self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype)

    def enum(self):
        # Type-tag enumerator name, |T<flat>|.
        return "T" + self.flattypename

    def enumvar(self):
        return ExprVar(self.enum())

    def internalType(self):
        # Recursive members are held by pointer; others by value.
        if self.recursive:
            return self.ptrToType()
        else:
            return self.bareType()

    def unionType(self):
        """Type used for storage in generated C union decl.

        A pointer for recursive members, otherwise
        |mozilla::AlignedStorage2<T>|.
        """
        if self.recursive:
            return self.ptrToType()
        else:
            return Type("mozilla::AlignedStorage2", T=self.internalType())

    def unionValue(self):
        # NB: knows that Union's storage C union is named |mValue|
        return ExprSelect(ExprVar("mValue"), ".", self.name)

    def typedef(self):
        return self.flattypename + "__tdef"

    def callGetConstPtr(self):
        """Return a call to the |constptr_<flat>()| accessor."""
        return ExprCall(ExprVar(self.getConstPtrName()))

    def callGetPtr(self):
        """Return a call to the |ptr_<flat>()| accessor."""
        return ExprCall(ExprVar(self.getPtrName()))

    def callCtor(self, expr=None):
        """Return an expression that constructs this member's value.

        |expr| is the optional single constructor argument. Array arguments
        that are not wrapped in |ExprMove| are copied with |.Clone()|, since
        nsTArray is marked as having no implicit copy ctor. Recursive members
        assign a heap-allocated |new T(args)| to the |ptr_<flat>()| slot. Other
        members use |ExprNew| with |newargs|
        (|mozilla::KnownNotNull, ptr_<flat>()|) to construct into the existing
        storage.
        """
        assert not isinstance(expr, list)

        if expr is None:
            args = None
        elif (
            self.ipdltype.isIPDL()
            and self.ipdltype.isArray()
            and not isinstance(expr, ExprMove)
        ):
            args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])]
        else:
            args = [expr]

        if self.recursive:
            return ExprAssn(self.callGetPtr(), ExprNew(self.bareType(), args=args))
        else:
            return ExprNew(
                self.bareType(),
                args=args,
                newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()],
            )

    def callDtor(self):
        # Recursive members are deleted through their pointer. Inline members
        # get an explicit |ptr_<flat>()->~<flat>__tdef()| destructor call.
        if self.recursive:
            return ExprDelete(self.callGetPtr())
        else:
            return ExprCall(ExprSelect(self.callGetPtr(), "->", "~" + self.typedef()))

    def getTypeName(self):
        return "get_" + self.flattypename

    def getConstTypeName(self):
        # Same name as |getTypeName|.
        return "get_" + self.flattypename

    def getOtherTypeName(self):
        # NOTE(review): |self.otherflattypename| is not assigned anywhere in
        # this class. Presumably it is set externally before this is called;
        # verify against callers.
        return "get_" + self.otherflattypename

    def getPtrName(self):
        return "ptr_" + self.flattypename

    def getConstPtrName(self):
        return "constptr_" + self.flattypename

    def ptrToSelfExpr(self):
        """|*ptrToSelfExpr()| has type |self.bareType()|.

        Recursive members already store a pointer. Inline members use the
        storage's |addr()|.
        """
        v = self.unionValue()
        if self.recursive:
            return v
        else:
            return ExprCall(ExprSelect(v, ".", "addr"))

    def constptrToSelfExpr(self):
        """Const counterpart of |ptrToSelfExpr|. The generated expression is
        the same; constness comes from the surrounding context."""
        v = self.unionValue()
        if self.recursive:
            return v
        return ExprCall(ExprSelect(v, ".", "addr"))

    def ptrToInternalType(self):
        # For recursive members, the stored pointer is accessed by reference.
        t = self.ptrToType()
        if self.recursive:
            t.ref = True
        return t

    def defaultValue(self, fq=False):
        """Return an initializer expression for this member's default value,
        or None."""
        # Use the default constructor for any class that does not have an
        # implicit copy constructor.
        if not self.bareType().hasimplicitcopyctor:
            return None

        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            return ExprLiteral.NULL
        # XXX sneaky here, maybe need ExprCtor()?
        return ExprCall(self.bareType(fq=fq))

    def getConstValue(self):
        """Return |*constptr_<flat>()|, const_cast to |Shmem&| for Shmem."""
        v = ExprDeref(self.callGetConstPtr())
        # sigh
        if "Shmem" == self.ipdltype.name():
            v = ExprCast(v, Type("Shmem", ref=True), const=True)
        return v


# --------------------------------------------------


class MessageDecl(ipdl.ast.MessageDecl):
    """IPDL message declaration augmented with C++ naming/codegen helpers.

    Instances are produced by MessageDecl.upgrade(), which re-classes an
    existing ipdl.ast.MessageDecl in place. The |namespace|, |params| and
    |returns| attributes read by these helpers are assigned by
    _DecorateWithCxxStuff.visitMessageDecl right before the upgrade.
    """

    def baseName(self):
        """Base identifier from which method/promise/resolver names derive."""
        return self.name

    def recvMethod(self):
        """Name of the receive handler: recv prefix + base name, plus
        "Constructor" for constructor messages."""
        name = _recvPrefix(self.decl.type) + self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        return name

    def sendMethod(self):
        """Name of the send method: send prefix + base name, plus
        "Constructor" for constructor messages."""
        name = _sendPrefix(self.decl.type) + self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        return name

    def hasReply(self):
        """True if the message type declares a reply. Constructor and
        destructor messages are treated as having a reply regardless."""
        return (
            self.decl.type.hasReply()
            or self.decl.type.isCtor()
            or self.decl.type.isDtor()
        )

    def hasAsyncReturns(self):
        """Truthy when the message is async and declares return values.

        NOTE: for async messages this returns |self.returns| (a list) rather
        than a strict bool.
        """
        return self.decl.type.isAsync() and self.returns

    def msgCtorFunc(self):
        """Name of the generated message factory function: |Msg_<progname>|."""
        return "Msg_%s" % (self.decl.progname)

    def prettyMsgName(self, pfx=""):
        """msgCtorFunc() with an optional prefix."""
        return pfx + self.msgCtorFunc()

    def pqMsgCtorFunc(self):
        """Namespace-qualified factory name: |<namespace>::Msg_<progname>|."""
        return "%s::%s" % (self.namespace, self.msgCtorFunc())

    def msgId(self):
        """Name of the message's enum id: |Msg_<progname>__ID|."""
        return self.msgCtorFunc() + "__ID"

    def pqMsgId(self):
        """Namespace-qualified message id."""
        return "%s::%s" % (self.namespace, self.msgId())

    def replyCtorFunc(self):
        """Name of the generated reply factory function: |Reply_<progname>|."""
        return "Reply_%s" % (self.decl.progname)

    def pqReplyCtorFunc(self):
        """Namespace-qualified reply factory name."""
        return "%s::%s" % (self.namespace, self.replyCtorFunc())

    def replyId(self):
        """Name of the reply's enum id: |Reply_<progname>__ID|."""
        return self.replyCtorFunc() + "__ID"

    def pqReplyId(self):
        """Namespace-qualified reply id."""
        return "%s::%s" % (self.namespace, self.replyId())

    def prettyReplyName(self, pfx=""):
        """replyCtorFunc() with an optional prefix."""
        return pfx + self.replyCtorFunc()

    def promiseName(self):
        """Name of the promise type: base name [+ "Constructor"] + "Promise"."""
        name = self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        name += "Promise"
        return name

    def resolverName(self):
        """Name of the resolver type: base name + "Resolver".

        Unlike promiseName(), no "Constructor" suffix is added.
        """
        return self.baseName() + "Resolver"

    def actorDecl(self):
        """First entry of |params|; for messages with an implicit actor
        parameter, upgrade() inserts that actor decl at index 0."""
        return self.params[0]

    def makeCxxParams(
        self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None
    ):
        """Return a list of C++ Decls for this message's params and returns.

        |paramsems| selects how |self.params| are declared: "in", "move",
        "out", or None (omit them).

        |returnsems| selects how |self.returns| are declared:
          - "in" / "move" / "out": one Decl per return value;
          - "promise": nothing is added;
          - "callback": an |aResolve| ResolveCallback and an |aReject|
            RejectCallback are added;
          - "resolver": a single |aResolve| resolver Decl is added;
          - None: returns are omitted.
        The "promise"/"callback"/"resolver" forms only apply when the message
        has returns; with no returns nothing is added for them.

        |side| is forwarded to the per-decl type helpers. When |direction| is
        "recv", tainted messages take params as |Tainted<T>| (unless the
        param has the NoTaint attribute), and "in" rvalue references are
        downgraded to lvalue references. "move" semantics require
        |direction| == "recv". When |implicit| is false and the message has
        an implicit actor parameter, that leading Decl is dropped.
        """

        def makeDecl(d, sems):
            # Build one Decl for |d| under semantics |sems| ("in"/"move"/"out").
            if (
                self.decl.type.tainted
                and "NoTaint" not in d.attributes
                and direction == "recv"
            ):
                # Tainted types are passed by-value, allowing the receiver to move them if desired.
                assert sems != "out"
                return Decl(Type("Tainted", T=d.bareType(side)), d.name)

            if sems == "in":
                t = d.inType(side, direction)
                # If this is the `recv` side, and we're not using "move"
                # semantics, that means we're an alloc method, and cannot accept
                # values by rvalue reference. Downgrade to an lvalue reference.
                if direction == "recv" and t.rvalref:
                    t.rvalref = False
                    t.ref = True
                return Decl(t, d.name)
            elif sems == "move":
                assert direction == "recv"
                # For legacy reasons, use an rvalue reference when generating
                # parameters for recv methods which accept arrays.
                if d.ipdltype.isIPDL() and d.ipdltype.isArray():
                    t = d.bareType(side)
                    t.rvalref = True
                    return Decl(t, d.name)
                return Decl(d.inType(side, direction), d.name)
            elif sems == "out":
                return Decl(d.outType(side), d.name)
            else:
                assert 0

        def makeResolverDecl(returns):
            # |returns| is unused; the resolver type name comes from resolverName().
            return Decl(Type(self.resolverName(), rvalref=True), "aResolve")

        def makeCallbackResolveDecl(returns):
            # Multiple return values are resolved together as a tuple.
            if len(returns) > 1:
                resolvetype = _tuple([d.bareType(side) for d in returns])
            else:
                resolvetype = returns[0].bareType(side)

            return Decl(
                Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True),
                "aResolve",
            )

        def makeCallbackRejectDecl(returns):
            # |returns| is unused; the reject callback type is fixed.
            return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject")

        cxxparams = []
        if paramsems is not None:
            cxxparams.extend([makeDecl(d, paramsems) for d in self.params])

        if returnsems == "promise" and self.returns:
            pass
        elif returnsems == "callback" and self.returns:
            cxxparams.extend(
                [
                    makeCallbackResolveDecl(self.returns),
                    makeCallbackRejectDecl(self.returns),
                ]
            )
        elif returnsems == "resolver" and self.returns:
            cxxparams.extend([makeResolverDecl(self.returns)])
        elif returnsems is not None:
            cxxparams.extend([makeDecl(r, returnsems) for r in self.returns])

        # Drop the implicit actor param (params[0], inserted by upgrade()).
        if not implicit and self.decl.type.hasImplicitActorParam():
            cxxparams = cxxparams[1:]

        return cxxparams

    def makeCxxArgs(
        self, paramsems="in", retsems="out", retcallsems="out", implicit=True
    ):
        """Return the C++ argument expressions for a call to this message's
        generated method (presumably one declared via makeCxxParams).

        |paramsems| is "in" (pass each param variable as-is) or "move" (wrap
        each in a move, except refcounted and NotNull params).

        For each return value, |retsems| is how the variable is held locally
        and |retcallsems| how the callee expects it: "in"->"out" passes the
        variable's address, "out"->"in" dereferences it, and matching kinds
        pass the variable unchanged. With |retsems| == "resolver" no return
        values are passed; instead a moved |resolver| variable is appended.
        Any other |retsems| value (e.g. None) passes no return values.

        When |implicit| is false the leading implicit actor argument is
        dropped; the message must have an implicit actor param in that case.
        """
        assert not retcallsems or retsems  # retcallsems => returnsems
        cxxargs = []

        if paramsems == "move":
            # We don't std::move() RefPtr<T> types because current Recv*()
            # implementors take these parameters as T*, and
            # std::move(RefPtr<T>) doesn't coerce to T*.
            # We also don't move NotNull, as it has no move constructor.
            cxxargs.extend(
                [
                    p.var()
                    if p.ipdltype.isRefcounted()
                    or (p.ipdltype.isIPDL() and p.ipdltype.isNotNull())
                    else ExprMove(p.var())
                    for p in self.params
                ]
            )
        elif paramsems == "in":
            cxxargs.extend([p.var() for p in self.params])
        else:
            assert False

        for ret in self.returns:
            if retsems == "in":
                if retcallsems == "in":
                    cxxargs.append(ret.var())
                elif retcallsems == "out":
                    cxxargs.append(ExprAddrOf(ret.var()))
                else:
                    assert 0
            elif retsems == "out":
                if retcallsems == "in":
                    cxxargs.append(ExprDeref(ret.var()))
                elif retcallsems == "out":
                    cxxargs.append(ret.var())
                else:
                    assert 0
            elif retsems == "resolver":
                pass
        if retsems == "resolver":
            cxxargs.append(ExprMove(ExprVar("resolver")))

        if not implicit:
            assert self.decl.type.hasImplicitActorParam()
            cxxargs = cxxargs[1:]

        return cxxargs

    @staticmethod
    def upgrade(messageDecl):
        """Re-class |messageDecl| as a MessageDecl in place.

        For messages with an implicit actor parameter, an |actor| decl whose
        type is the ActorType of the constructed protocol is first inserted at
        the front of |params|.
        """
        assert isinstance(messageDecl, ipdl.ast.MessageDecl)
        if messageDecl.decl.type.hasImplicitActorParam():
            messageDecl.params.insert(
                0,
                _HybridDecl(
                    ipdl.type.ActorType(messageDecl.decl.type.constructedType()),
                    "actor",
                ),
            )
        messageDecl.__class__ = MessageDecl


# --------------------------------------------------
def _usesShmem(p):
    """Return True if any message of protocol |p| has an in- or out-param
    whose type contains a Shmem (as reported by ipdl.type.hasshmem)."""
    return any(
        ipdl.type.hasshmem(param.type)
        for md in p.messageDecls
        for paramList in (md.inParams, md.outParams)
        for param in paramList
    )


def _subtreeUsesShmem(p):
    """Return True if protocol |p|, or any protocol it manages (recursively),
    uses Shmem.

    A protocol that manages itself is not revisited. Longer management
    cycles are not guarded against here (presumably rejected earlier by the
    type checker — TODO confirm).
    """
    if _usesShmem(p):
        return True
    ptype = p.decl.type
    return any(
        _subtreeUsesShmem(managed._ast)
        for managed in ptype.manages
        if managed is not ptype
    )


class Protocol(ipdl.ast.Protocol):
    """IPDL protocol AST node augmented with C++ code-generation helpers.

    Instances are obtained by re-classing an existing ipdl.ast.Protocol via
    Protocol.upgrade(), not by direct construction.
    """

    def managerInterfaceType(self, ptr=False):
        """C++ type |mozilla::ipc::IProtocol| (optionally a pointer)."""
        return Type("mozilla::ipc::IProtocol", ptr=ptr)

    def openedProtocolInterfaceType(self, ptr=False):
        """C++ type |mozilla::ipc::IToplevelProtocol| (optionally a pointer)."""
        return Type("mozilla::ipc::IToplevelProtocol", ptr=ptr)

    def _ipdlmgrtype(self):
        """Return this protocol's single manager type (asserted unique)."""
        managers = self.decl.type.managers
        assert 1 == len(managers)
        return next(iter(managers), None)

    def managerActorType(self, side, ptr=False):
        """C++ actor type of the manager protocol on |side|."""
        manager = self._ipdlmgrtype()
        return Type(_actorName(manager.name(), side), ptr=ptr)

    def unregisterMethod(self, actorThis=None):
        """|Unregister|, or |actorThis->Unregister| when an actor is given."""
        if actorThis is None:
            return ExprVar("Unregister")
        return ExprSelect(actorThis, "->", "Unregister")

    def removeManageeMethod(self):
        return ExprVar("RemoveManagee")

    def deallocManageeMethod(self):
        return ExprVar("DeallocManagee")

    def getChannelMethod(self):
        return ExprVar("GetIPCChannel")

    def callGetChannel(self, actorThis=None):
        """Call expression for GetIPCChannel(), on |actorThis| if given."""
        method = self.getChannelMethod()
        if actorThis is None:
            return ExprCall(method)
        return ExprCall(ExprSelect(actorThis, "->", method.name))

    def processingErrorVar(self):
        """|ProcessingError|; only valid for toplevel protocols."""
        assert self.decl.type.isToplevel()
        return ExprVar("ProcessingError")

    def shouldContinueFromTimeoutVar(self):
        """|ShouldContinueFromReplyTimeout|; only valid for toplevel protocols."""
        assert self.decl.type.isToplevel()
        return ExprVar("ShouldContinueFromReplyTimeout")

    def routingId(self, actorThis=None):
        """Routing id expression.

        Toplevel protocols use MSG_ROUTING_CONTROL; managed protocols call
        Id(), on |actorThis| when one is given.
        """
        if self.decl.type.isToplevel():
            return ExprVar("MSG_ROUTING_CONTROL")
        if actorThis is None:
            idMethod = ExprVar("Id")
        else:
            idMethod = ExprSelect(actorThis, "->", "Id")
        return ExprCall(idMethod)

    def managerVar(self, thisexpr=None):
        """Call expression for Manager(), on |thisexpr| when given.

        Toplevel protocols must pass |thisexpr|.
        """
        assert thisexpr is not None or not self.decl.type.isToplevel()
        if thisexpr is None:
            managerMethod = ExprVar("Manager")
        else:
            managerMethod = ExprSelect(thisexpr, "->", "Manager")
        return ExprCall(managerMethod, args=[])

    def _managedActorName(self, actortype, side):
        """Actor class name for managee |actortype| on |side|.

        Asserts that this protocol manages |actortype|.
        """
        assert self.decl.type.isManagerOf(actortype)
        return _actorName(actortype.name(), side)

    def managedCxxType(self, actortype, side):
        """Pointer type to the managee's actor class."""
        return Type(self._managedActorName(actortype, side), ptr=True)

    def managedMethod(self, actortype, side):
        """|Managed<Actor>| accessor name."""
        return ExprVar("Managed" + self._managedActorName(actortype, side))

    def managedVar(self, actortype, side):
        """|mManaged<Actor>| member variable."""
        return ExprVar("mManaged" + self._managedActorName(actortype, side))

    def managedVarType(self, actortype, side, const=False, ref=False):
        """Container type used to hold managees of |actortype|."""
        elementType = Type(self._managedActorName(actortype, side))
        return _cxxManagedContainerType(elementType, const=const, ref=ref)

    def subtreeUsesShmem(self):
        """True if this protocol or any managed descendant uses Shmem."""
        return _subtreeUsesShmem(self)

    @staticmethod
    def upgrade(protocol):
        """Re-class an ipdl.ast.Protocol as Protocol in place."""
        assert isinstance(protocol, ipdl.ast.Protocol)
        protocol.__class__ = Protocol


class TranslationUnit(ipdl.ast.TranslationUnit):
    """Marker subclass for translation units already decorated for C++."""

    @staticmethod
    def upgrade(tu):
        """Re-class the ipdl.ast.TranslationUnit |tu| as TranslationUnit."""
        assert isinstance(tu, ipdl.ast.TranslationUnit)
        upgraded = TranslationUnit
        tu.__class__ = upgraded


# -----------------------------------------------------------------------------

# Byte sizes of the C++ plain-old-data types recognised when packing struct
# fields, keyed by fully-qualified type name. Insertion order is:
# ::int8_t, ::uint8_t, ::int16_t, ::uint16_t, ::int32_t, ::uint32_t,
# ::int64_t, ::uint64_t, float, double.
pod_types = {
    "::%sint%d_t" % (signPrefix, bits): bits // 8
    for bits in (8, 16, 32, 64)
    for signPrefix in ("", "u")
}
pod_types.update({"float": 4, "double": 8})
max_pod_size = max(pod_types.values())
# Unrecognised types are reported as strictly "bigger" than every POD type,
# which makes them sort first when fields are ordered by descending size.
pod_size_sentinel = 2 * max_pod_size


def pod_size(ipdltype):
    """Return the byte size of |ipdltype| when it is a known C++ POD type.

    Non-C++ types and C++ types missing from pod_types both yield
    pod_size_sentinel, which exceeds every real POD size.
    """
    if ipdltype.isCxx():
        return pod_types.get(ipdltype.fullname(), pod_size_sentinel)
    return pod_size_sentinel


class _DecorateWithCxxStuff(ipdl.ast.Visitor):
    """Lowering phase 1: annotate the IPDL AST with C++ codegen information.

    The resulting tree is a lightweight "IR": mostly the original IPDL nodes,
    re-classed into C++-aware subclasses (TranslationUnit, Protocol,
    StructDecl, UnionDecl, MessageDecl) and carrying extra hybrid IPDL/C++
    nodes shaped for the C++ code generator.
    """

    def __init__(self):
        # Translation units already decorated; each is processed once even
        # when reachable through several includes.
        self.visitedTus = set()
        # Name of the protocol currently being visited; stamped onto each
        # message as its |namespace|.
        self.protocolName = None

    def visitTranslationUnit(self, tu):
        if tu in self.visitedTus:
            return
        self.visitedTus.add(tu)
        ipdl.ast.Visitor.visitTranslationUnit(self, tu)
        if not isinstance(tu, TranslationUnit):
            TranslationUnit.upgrade(tu)

    def visitInclude(self, inc):
        # Only header (.ipdlh) includes are decorated from here.
        if inc.tu.filetype != "header":
            return
        inc.tu.accept(self)

    def visitProtocol(self, pro):
        self.protocolName = pro.name
        Protocol.upgrade(pro)
        return ipdl.ast.Visitor.visitProtocol(self, pro)

    def visitStructDecl(self, sd):
        # Already-upgraded structs are left alone.
        if isinstance(sd, StructDecl):
            return
        fields = [_StructField(f.decl.type, f.name, sd) for f in sd.fields]

        # Storage order: indices of |fields| sorted by descending POD size so
        # the in-memory layout packs well. pod_size() reports non-POD types
        # as largest, so they come first; the sort is stable, so equally
        # sized fields keep their declaration order.
        packedOrder = sorted(
            range(len(fields)),
            key=lambda idx: pod_size(fields[idx].ipdltype),
            reverse=True,
        )

        sd.fields = fields
        sd.packed_field_order = packedOrder
        StructDecl.upgrade(sd)

    def visitUnionDecl(self, ud):
        members = [_UnionMember(ctype, ud) for ctype in ud.decl.type.components]
        ud.components = members
        UnionDecl.upgrade(ud)

    def visitDecl(self, decl):
        return _HybridDecl(decl.type, decl.progname, decl.attributes)

    def visitMessageDecl(self, md):
        md.namespace = self.protocolName
        inDecls = [param.accept(self) for param in md.inParams]
        md.params = inDecls
        outDecls = [ret.accept(self) for ret in md.outParams]
        md.returns = outDecls
        MessageDecl.upgrade(md)


# -----------------------------------------------------------------------------


def msgenums(protocol, pretty=False):
    """Build the |MessageType| enum for |protocol|.

    The enum starts with |<Name>Start = <message start> << 16|, then lists one
    id per message (plus one for its reply when the message has a reply), and
    ends with |<Name>End|. With |pretty| the entries use prettyMsgName() /
    prettyReplyName() instead of msgId() / replyId().
    """
    enum = TypeEnum("MessageType")
    enum.addId(
        protocol.name + "Start",
        _messageStartName(protocol.decl.type) + " << 16",
    )

    for md in protocol.messageDecls:
        if pretty:
            enum.addId(md.prettyMsgName())
            if md.hasReply():
                enum.addId(md.prettyReplyName())
        else:
            enum.addId(md.msgId())
            if md.hasReply():
                enum.addId(md.replyId())

    enum.addId(protocol.name + "End")
    return enum


class _GenerateProtocolCode(ipdl.ast.Visitor):
    """Creates code common to both the parent and child actors.

    Fills the translation unit's own header and source File objects (in
    LowerToCxx.lower these are |<tu.name>.h| and |<tu.name>.cpp|). The header
    receives includes, struct/union declarations and, for a protocol, the
    MessageType enum plus declarations of CreateEndpoints and the message
    constructors. The source file receives includes, the matching function
    definitions, and the struct/union method and ParamTraits definitions.
    """

    def __init__(self):
        self.protocol = None  # protocol we're generating a class for
        self.hdrfile = None  # what will become Protocol.h
        self.cppfile = None  # what will become Protocol.cpp
        # Header names emitted as #include directives at the top of the .cpp.
        self.cppIncludeHeaders = []
        # Struct/union method definitions and ParamTraits definitions,
        # appended to the .cpp after the protocol namespace.
        self.structUnionDefns = []
        # Definitions of free functions (CreateEndpoints, message
        # constructors) emitted inside the protocol namespace in the .cpp.
        self.funcDefns = []
        # NOTE: self.segmentcapacitydict is only assigned in lower().

    def lower(self, tu, cxxHeaderFile, cxxFile, segmentcapacitydict):
        """Generate the common code for |tu| into |cxxHeaderFile|/|cxxFile|.

        |segmentcapacitydict| maps "<Protocol>::<message progname>" strings to
        segment capacities; messages missing from it use 0 (see
        visitProtocol).
        """
        self.protocol = tu.protocol
        self.hdrfile = cxxHeaderFile
        self.cppfile = cxxFile
        self.segmentcapacitydict = segmentcapacitydict
        tu.accept(self)

    def visitTranslationUnit(self, tu):
        """Emit the whole header, then the .cpp file, for |tu|.

        Header order: disclaimer, include guard, builtin includes, headers
        needed by struct/union typedefs, C++ includes, IPDL includes,
        struct/union declarations (dependency-ordered), usings, protocol
        declarations, include guard end. The .cpp file is assembled last,
        once all include headers and definitions have been collected.
        """
        hf = self.hdrfile

        hf.addthing(_DISCLAIMER)
        hf.addthings(_includeGuardStart(hf))
        hf.addthing(Whitespace.NL)

        for inc in builtinHeaderIncludes:
            self.visitBuiltinCxxInclude(inc)

        # Compute the set of includes we need for declared structure/union
        # classes for this protocol.
        # typesToIncludes: str(using type) -> header declaring it; every
        # `using` of the same type must name the same header.
        typesToIncludes = {}
        for using in tu.using:
            typestr = str(using.type)
            if typestr not in typesToIncludes:
                typesToIncludes[typestr] = using.header
            else:
                assert typesToIncludes[typestr] == using.header

        aggregateTypeIncludes = set()
        for su in tu.structsAndUnions:
            typedeps = _ComputeTypeDeps(su.decl.type, typesToIncludes)
            if isinstance(su, ipdl.ast.StructDecl):
                aggregateTypeIncludes.add("mozilla/ipc/IPDLStructMember.h")
                for f in su.fields:
                    f.ipdltype.accept(typedeps)
            elif isinstance(su, ipdl.ast.UnionDecl):
                for c in su.components:
                    c.ipdltype.accept(typedeps)

            aggregateTypeIncludes.update(typedeps.includeHeaders)

        if len(aggregateTypeIncludes) != 0:
            hf.addthing(Whitespace.NL)
            hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL])

            # Sorted so the generated include list is deterministic.
            for headername in sorted(iter(aggregateTypeIncludes)):
                hf.addthing(CppDirective("include", '"' + headername + '"'))

        # Manually run Visitor.visitTranslationUnit. For dependency resolution
        # we need to handle structs and unions separately.
        for cxxInc in tu.cxxIncludes:
            cxxInc.accept(self)
        for inc in tu.includes:
            inc.accept(self)
        self.generateStructsAndUnions(tu)
        for using in tu.builtinUsing:
            using.accept(self)
        for using in tu.using:
            using.accept(self)
        if tu.protocol:
            tu.protocol.accept(self)

        # For a header (.ipdlh) TU, the .cpp includes its own generated header.
        if tu.filetype == "header":
            self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h")

        hf.addthing(Whitespace.NL)
        hf.addthings(_includeGuardEnd(hf))

        cf = self.cppfile
        cf.addthings(
            (
                [_DISCLAIMER, Whitespace.NL]
                + [
                    CppDirective("include", '"' + h + '"')
                    for h in self.cppIncludeHeaders
                ]
                + [Whitespace.NL]
                + [
                    CppDirective("include", '"%s"' % filename)
                    for filename in ipdl.builtin.CppIncludes
                ]
                + [Whitespace.NL]
            )
        )

        if self.protocol:
            # construct the namespace into which we'll stick all our defns
            ns = Namespace(self.protocol.name)
            cf.addthing(_putInNamespaces(ns, self.protocol.namespaces))
            ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL]))

        cf.addthings(self.structUnionDefns)

    def visitBuiltinCxxInclude(self, inc):
        """Emit a quoted #include for a builtin header into the header file."""
        self.hdrfile.addthing(CppDirective("include", '"' + inc.file + '"'))

    def visitCxxInclude(self, inc):
        """Record a C++ include to be emitted in the .cpp file."""
        self.cppIncludeHeaders.append(inc.file)

    def visitInclude(self, inc):
        """Handle an IPDL include.

        Header (.ipdlh) includes become a header #include, and their C++
        includes are inherited into the .cpp. Protocol includes add the
        included protocol's parent and child headers to the .cpp.
        """
        if inc.tu.filetype == "header":
            self.hdrfile.addthing(
                CppDirective("include", '"' + _ipdlhHeaderName(inc.tu) + '.h"')
            )
            # Inherit cpp includes defined by imported header files, as they may
            # be required to serialize an imported `using` type.
            for cxxinc in inc.tu.cxxIncludes:
                cxxinc.accept(self)
        else:
            # NOTE(review): lowercase "parent"/"child" here versus
            # "Parent"/"Child" in visitProtocol; presumably
            # _protocolHeaderName handles both — confirm.
            self.cppIncludeHeaders += [
                _protocolHeaderName(inc.tu.protocol, "parent") + ".h",
                _protocolHeaderName(inc.tu.protocol, "child") + ".h",
            ]

    def generateStructsAndUnions(self, tu):
        """Generate the definitions for all structs and unions. This will
        re-order the declarations if needed in the C++ code such that
        dependencies have already been defined."""
        # Maps each struct/union IPDL type to (types it fully depends on,
        # header items declaring it), in source order.
        decls = OrderedDict()
        for su in tu.structsAndUnions:
            if isinstance(su, StructDecl):
                which = "struct"
                forwarddecls, fulldecltypes, cls = _generateCxxStruct(su)
                traitsdecl, traitsdefns = _ParamTraits.structPickling(su.decl.type)
            else:
                assert isinstance(su, UnionDecl)
                which = "union"
                forwarddecls, fulldecltypes, cls = _generateCxxUnion(su)
                traitsdecl, traitsdefns = _ParamTraits.unionPickling(su.decl.type)

            clsdecl, methoddefns = _splitClassDeclDefn(cls)

            # Store the declarations in the decls map so we can emit in
            # dependency order.
            decls[su.decl.type] = (
                fulldecltypes,
                [Whitespace.NL]
                + forwarddecls
                + [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Declaration of the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(clsdecl, su.namespaces),
                ]
                + [Whitespace.NL, traitsdecl],
            )

            # Method definitions need no ordering; they go to the .cpp.
            self.structUnionDefns.extend(
                [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Method definitions for the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(methoddefns, su.namespaces),
                    Whitespace.NL,
                    traitsdefns,
                ]
            )

        # Generate the declarations structs in dependency order.
        # Depth-first: each dependency still pending in |decls| is removed
        # and emitted before |defn|. Dependencies no longer in |decls| (already
        # emitted, or not declared in this TU) are skipped.
        def gen_struct(deps, defn):
            for dep in deps:
                if dep in decls:
                    d, t = decls[dep]
                    del decls[dep]
                    gen_struct(d, t)
            self.hdrfile.addthings(defn)

        # popitem(False) takes entries in FIFO (source) order.
        while len(decls) > 0:
            _, (d, t) = decls.popitem(False)
            gen_struct(d, t)

    def visitProtocol(self, p):
        """Emit protocol-wide declarations into the header and collect the
        matching definitions into self.funcDefns.

        This adds the common/parent/child headers to the .cpp includes,
        forward-declares both actors, and emits (in the protocol namespace)
        CreateEndpoints, the MessageType enum, and per-message constructors
        (plus reply constructors for messages with a reply).
        """
        self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h")
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Parent") + ".h"
        )
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Child") + ".h"
        )

        # Forward declare our own actors.
        self.hdrfile.addthings(
            [
                Whitespace.NL,
                _makeForwardDeclForActor(p.decl.type, "Parent"),
                _makeForwardDeclForActor(p.decl.type, "Child"),
            ]
        )

        self.hdrfile.addthing(
            Whitespace(
                """
//-----------------------------------------------------------------------------
// Code common to %sChild and %sParent
//
"""
                % (p.name, p.name)
            )
        )

        # construct the namespace into which we'll stick all our decls
        ns = Namespace(self.protocol.name)
        self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces))
        ns.addstmt(Whitespace.NL)

        # Declarations go in the header namespace; definitions are deferred
        # to the .cpp via self.funcDefns.
        for func in self.genEndpointFuncs():
            edecl, edefn = _splitFuncDeclDefn(func)
            ns.addstmts([edecl, Whitespace.NL])
            self.funcDefns.append(edefn)

        # spit out message type enum and classes
        msgenum = msgenums(self.protocol)
        ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL])

        for md in p.messageDecls:
            decls = []

            # Look up the segment capacity used for serializing this
            # message. If the capacity is not specified, use '0' for
            # the default capacity (defined in ipc_message.cc)
            name = "%s::%s" % (md.namespace, md.decl.progname)
            segmentcapacity = self.segmentcapacitydict.get(name, 0)

            mfDecl, mfDefn = _splitFuncDeclDefn(
                _generateMessageConstructor(md, segmentcapacity, p, forReply=False)
            )
            decls.append(mfDecl)
            self.funcDefns.append(mfDefn)

            # Replies always use the default segment capacity (0).
            if md.hasReply():
                rfDecl, rfDefn = _splitFuncDeclDefn(
                    _generateMessageConstructor(md, 0, p, forReply=True)
                )
                decls.append(rfDecl)
                self.funcDefns.append(rfDefn)

            decls.append(Whitespace.NL)
            ns.addstmts(decls)

        ns.addstmts([Whitespace.NL, Whitespace.NL])

    # Generate code for PFoo::CreateEndpoints.
    def genEndpointFuncs(self):
        """Return the CreateEndpoints overload(s) for this protocol.

        One overload always takes |aParentDestPid| and |aChildDestPid| before
        the two Endpoint out-pointers. A second overload without the pid
        parameters is added when the protocol type's hasOtherPid() is false.
        Both forward their arguments to mozilla::ipc::CreateEndpoints.
        """
        p = self.protocol.decl.type
        tparent = _cxxBareType(ActorType(p), "Parent", fq=True)
        tchild = _cxxBareType(ActorType(p), "Child", fq=True)

        def mkOverload(includepids):
            # Build one CreateEndpoints definition returning nsresult.
            params = []
            if includepids:
                params = [
                    Decl(Type("base::ProcessId"), "aParentDestPid"),
                    Decl(Type("base::ProcessId"), "aChildDestPid"),
                ]
            params += [
                Decl(
                    Type("mozilla::ipc::Endpoint<" + tparent.name + ">", ptr=True),
                    "aParent",
                ),
                Decl(
                    Type("mozilla::ipc::Endpoint<" + tchild.name + ">", ptr=True),
                    "aChild",
                ),
            ]
            openfunc = MethodDefn(
                MethodDecl("CreateEndpoints", params=params, ret=Type.NSRESULT)
            )
            openfunc.addcode(
                """
                return mozilla::ipc::CreateEndpoints(
                    mozilla::ipc::PrivateIPDLInterface(),
                    $,{args});
                """,
                args=[ExprVar(d.name) for d in params],
            )
            return openfunc

        funcs = [mkOverload(True)]
        if not p.hasOtherPid():
            funcs.append(mkOverload(False))
        return funcs


# --------------------------------------------------

# C++ priority enum names (e.g. "NORMAL_PRIORITY"), indexed in parallel with
# ipdl.ast.priorityList so a message's numeric priority selects its name.
cppPriorityList = [prio.upper() + "_PRIORITY" for prio in ipdl.ast.priorityList]


def _generateMessageConstructor(md, segmentSize, protocol, forReply=False):
    """Build the C++ function that creates the IPC::Message for |md|.

    The generated function is named md.msgCtorFunc() (or md.replyCtorFunc()
    when |forReply| is true), takes an |int32_t routingId| and returns
    |mozilla::UniquePtr<IPC::Message>| by calling IPC::Message::IPDLMessage
    with the routing id, the message (or reply) id, the segment capacity and
    header flags. The flags are derived from the message's nesting, priority
    (replyPrio for replies), compression, lazy-send, constructor, sync and
    reply attributes.

    |segmentSize| is the segment capacity to request. 0 selects the default
    capacity; any other falsy value (e.g. None or "") is treated the same.
    Truthy values are converted with int(). |protocol| is not used.

    Interrupt messages get a static_assert(false, ...) in the generated body.
    """
    if forReply:
        clsname = md.replyCtorFunc()
        msgid = md.replyId()
        replyEnum = "REPLY"
        prioEnum = cppPriorityList[md.decl.type.replyPrio]
    else:
        clsname = md.msgCtorFunc()
        msgid = md.msgId()
        replyEnum = "NOT_REPLY"
        prioEnum = cppPriorityList[md.decl.type.prio]

    nested = md.decl.type.nested
    compress = md.decl.type.compress
    lazySend = md.decl.type.lazySend

    routingId = ExprVar("routingId")

    func = FunctionDefn(
        FunctionDecl(
            clsname,
            params=[Decl(Type("int32_t"), routingId.name)],
            ret=Type("mozilla::UniquePtr<IPC::Message>"),
        )
    )

    # A falsy |compress| means no compression; "all" means always compress;
    # a bare compress attribute (value None) enables compression.
    if not compress:
        compression = "COMPRESSION_NONE"
    elif compress.value == "all":
        compression = "COMPRESSION_ALL"
    else:
        assert compress.value is None
        compression = "COMPRESSION_ENABLED"

    lazySendEnum = "LAZY_SEND" if lazySend else "EAGER_SEND"

    if nested == ipdl.ast.NOT_NESTED:
        nestedEnum = "NOT_NESTED"
    elif nested == ipdl.ast.INSIDE_SYNC_NESTED:
        nestedEnum = "NESTED_INSIDE_SYNC"
    else:
        assert nested == ipdl.ast.INSIDE_CPOW_NESTED
        nestedEnum = "NESTED_INSIDE_CPOW"

    syncEnum = "SYNC" if md.decl.type.isSync() else "ASYNC"

    # FIXME(bug ???) - remove support for interrupt messages from the IPDL compiler.
    if md.decl.type.isInterrupt():
        func.addcode(
            """
            static_assert(
                false,
                "runtime support for intr messages has been removed from IPDL");
            """
        )

    ctorEnum = "CONSTRUCTOR" if md.decl.type.isCtor() else "NOT_CONSTRUCTOR"

    def messageEnum(valname):
        return ExprVar("IPC::Message::" + valname)

    flags = ExprCall(
        ExprVar("IPC::Message::HeaderFlags"),
        args=[
            messageEnum(nestedEnum),
            messageEnum(prioEnum),
            messageEnum(compression),
            messageEnum(lazySendEnum),
            messageEnum(ctorEnum),
            messageEnum(syncEnum),
            messageEnum(replyEnum),
        ],
    )

    # Normalize before converting: a missing or empty capacity must fall back
    # to the default (0) rather than make int() raise.
    segmentSize = int(segmentSize) if segmentSize else 0
    func.addstmt(
        StmtReturn(
            ExprCall(
                ExprVar("IPC::Message::IPDLMessage"),
                args=[
                    routingId,
                    ExprVar(msgid),
                    ExprLiteral.Int(segmentSize),
                    flags,
                ],
            )
        )
    )

    return func


# --------------------------------------------------


class _ParamTraits:
    """Classmethod helpers that build C++ AST for IPC::ParamTraits<>
    specializations (see generateDecl) and the write/read statements that
    go inside them, including sentinel checks around each serialized value.
    """

    # Parameter names used by the generated Write()/Read() methods (see
    # generateDecl): the value being written, the IPC::MessageWriter*, and
    # the IPC::MessageReader*.
    var = ExprVar("aVar")
    writervar = ExprVar("aWriter")
    readervar = ExprVar("aReader")

    @classmethod
    def ifsideis(cls, rdrwtr, side, then, els=None):
        """Return a C++ |if| statement that runs |then| when the actor of
        |rdrwtr| is on |side|.

        The emitted condition is `<Side> == rdrwtr->GetActor()->GetSide()`,
        where <Side> is mozilla::ipc::ParentSide when |side| == "parent" and
        mozilla::ipc::ChildSide for any other value.  |els|, if not None, is
        added as the else branch.
        """
        cxxside = ExprVar("mozilla::ipc::ChildSide")
        if side == "parent":
            cxxside = ExprVar("mozilla::ipc::ParentSide")

        ifstmt = StmtIf(
            ExprBinary(
                cxxside,
                "==",
                ExprCode("${rdrwtr}->GetActor()->GetSide()", rdrwtr=rdrwtr),
            )
        )
        ifstmt.addifstmt(then)
        if els is not None:
            ifstmt.addelsestmt(els)
        return ifstmt

    @classmethod
    def fatalError(cls, rdrwtr, reason):
        """Return the C++ statement `rdrwtr->FatalError("<reason>");`, with
        |reason| emitted as a string literal."""
        return StmtCode(
            "${rdrwtr}->FatalError(${reason});",
            rdrwtr=rdrwtr,
            reason=ExprLiteral.String(reason),
        )

    @classmethod
    def writeSentinel(cls, writervar, sentinelKey):
        """Return statements that write the sentinel for |sentinelKey|.

        Emits a `// Sentinel = ...` comment followed by
        `writervar->WriteSentinel(<hash>)`, where <hash> is
        hashfunc(sentinelKey).  The reading side must check the same key
        (see readSentinel).
        """
        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            StmtExpr(
                ExprCall(
                    ExprSelect(writervar, "->", "WriteSentinel"),
                    args=[ExprLiteral.Int(hashfunc(sentinelKey))],
                )
            ),
        ]

    @classmethod
    def readSentinel(cls, readervar, sentinelKey, sentinelFail):
        """Return statements that check the sentinel for |sentinelKey|.

        Emits a `// Sentinel = ...` comment followed by
        `if (!readervar->ReadSentinel(<hash>)) { <sentinelFail> }`.
        |sentinelFail| is a list of statements run when ReadSentinel()
        returns false.
        """
        # Read the sentinel
        read = ExprCall(
            ExprSelect(readervar, "->", "ReadSentinel"),
            args=[ExprLiteral.Int(hashfunc(sentinelKey))],
        )
        ifsentinel = StmtIf(ExprNot(read))
        ifsentinel.addifstmts(sentinelFail)

        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            ifsentinel,
        ]

    @classmethod
    def write(cls, var, writervar, ipdltype=None):
        """Return the C++ expression `IPC::WriteParam(writervar, var)`.

        When |ipdltype| is given and _cxxTypeNeedsMoveForSend(ipdltype) is
        true, |var| is wrapped in a move first.  |ipdltype| may be None
        (e.g. for the union type tag), in which case no move is applied.
        """
        if ipdltype and _cxxTypeNeedsMoveForSend(ipdltype):
            var = ExprMove(var)
        return ExprCall(ExprVar("IPC::WriteParam"), args=[writervar, var])

    @classmethod
    def checkedWrite(cls, ipdltype, var, writervar, sentinelKey):
        """Return a Block that writes |var| (see write) followed by the
        sentinel for |sentinelKey|.  |sentinelKey| must be non-empty."""
        assert sentinelKey
        block = Block()

        block.addstmts(
            [
                StmtExpr(cls.write(var, writervar, ipdltype)),
            ]
        )
        block.addstmts(cls.writeSentinel(writervar, sentinelKey))
        return block

    @classmethod
    def bulkSentinelKey(cls, fields):
        """Return the sentinel key for a bulk-copied run of |fields|: their
        basenames joined with " | "."""
        return " | ".join(f.basename for f in fields)

    @classmethod
    def checkedBulkWrite(cls, var, size, fields):
        """Return a Block that writes a run of POD |fields| of |var| as one
        raw byte range, followed by a sentinel.

        Emits `aWriter->WriteBytes(&var.<first field>(), size * len(fields))`.
        |fields| must be non-empty.  The fields are assumed to be laid out
        contiguously with stride |size|: structPickling groups them by
        pod_size() in member order, and _fieldStaticAssertions emits
        static_asserts for that layout.  Always uses cls.writervar.
        """
        block = Block()
        first = fields[0]

        block.addstmts(
            [
                StmtExpr(
                    ExprCall(
                        ExprSelect(cls.writervar, "->", "WriteBytes"),
                        args=[
                            ExprAddrOf(
                                ExprCall(first.getMethod(thisexpr=var, sel="."))
                            ),
                            ExprLiteral.Int(size * len(fields)),
                        ],
                    )
                )
            ]
        )
        block.addstmts(cls.writeSentinel(cls.writervar, cls.bulkSentinelKey(fields)))

        return block

    @classmethod
    def checkedBulkRead(cls, var, size, fields):
        """Reading counterpart of checkedBulkWrite.

        Emits `aReader->ReadBytesInto(&var-><first field>(), size *
        len(fields))`.  On failure the generated code calls FatalError and
        returns readResultError().  The bulk sentinel is then checked, with
        failure statements built by the module-level errfnSentinel helper.
        |var| is accessed with "->" because structPickling passes the
        IPC::ReadResult local that is being filled in.
        """
        block = Block()
        first = fields[0]

        readbytes = ExprCall(
            ExprSelect(cls.readervar, "->", "ReadBytesInto"),
            args=[
                ExprAddrOf(ExprCall(first.getMethod(thisexpr=var, sel="->"))),
                ExprLiteral.Int(size * len(fields)),
            ],
        )
        ifbad = StmtIf(ExprNot(readbytes))
        errmsg = "Error bulk reading fields from %s" % first.ipdltype.name()
        ifbad.addifstmts(
            [cls.fatalError(cls.readervar, errmsg), StmtReturn(readResultError())]
        )
        block.addstmt(ifbad)
        block.addstmts(
            cls.readSentinel(
                cls.readervar,
                cls.bulkSentinelKey(fields),
                errfnSentinel(readResultError())(errmsg),
            )
        )

        return block

    @classmethod
    def checkedRead(
        cls,
        ipdltype,
        cxxtype,
        var,
        readervar,
        errfn,
        paramtype,
        sentinelKey,
        errfnSentinel,
    ):
        """Return a Block that reads a |cxxtype| into a new C++ local |var|,
        then checks the sentinel for |sentinelKey|.

        Emits:
            auto maybe__<var> = IPC::ReadParam<cxxtype>(reader);
            if (!maybe__<var>) { <errfn(...)> }
            auto& <var> = *maybe__<var>;
        followed by readSentinel(...) with errfnSentinel(...) as the failure
        statements.

        |paramtype| is either a string, which becomes the single argument
        "Error deserializing " + paramtype, or a list of arguments.  Either
        way the arguments are splatted into both |errfn| and
        |errfnSentinel|.  |var| must be an ExprVar.  |ipdltype| is not used
        by this method.
        """
        assert isinstance(var, ExprVar)

        if not isinstance(paramtype, list):
            paramtype = ["Error deserializing " + paramtype]

        block = Block()

        # Read the data
        block.addcode(
            """
            auto ${maybevar} = IPC::ReadParam<${ty}>(${reader});
            if (!${maybevar}) {
                $*{errfn}
            }
            auto& ${var} = *${maybevar};
            """,
            maybevar=ExprVar("maybe__" + var.name),
            ty=cxxtype,
            reader=readervar,
            errfn=errfn(*paramtype),
            var=var,
        )

        block.addstmts(
            cls.readSentinel(readervar, sentinelKey, errfnSentinel(*paramtype))
        )

        return block

    # Helper wrapper for checkedRead for use within _ParamTraits
    @classmethod
    def _checkedRead(cls, ipdltype, cxxtype, var, sentinelKey, what):
        """checkedRead using cls.readervar.

        On a failed read the generated code calls FatalError and returns
        readResultError().  |what| describes the value for the error
        message.
        """

        def errfn(msg):
            return [cls.fatalError(cls.readervar, msg), StmtReturn(readResultError())]

        return cls.checkedRead(
            ipdltype,
            cxxtype,
            var,
            cls.readervar,
            errfn=errfn,
            paramtype=what,
            sentinelKey=sentinelKey,
            errfnSentinel=errfnSentinel(readResultError()),
        )

    @classmethod
    def generateDecl(cls, fortype, write, read, needsmove=False):
        """Build an `IPC::ParamTraits<fortype>` specialization.

        |write| and |read| are lists of statements used as the bodies of:
          static void Write(IPC::MessageWriter*, <const paramType& | paramType&&>)
          static IPC::ReadResult<paramType> Read(IPC::MessageReader*)
        Write takes `paramType&&` when |needsmove| is true, otherwise
        `const paramType&`.

        Returns (clsns, defns).  These are the class declaration and the
        out-of-line method definitions, as split by _splitClassDeclDefn,
        each wrapped in `namespace IPC`.
        """
        # ParamTraits impls are selected ignoring constness, and references.
        pt = Class(
            "ParamTraits",
            specializes=Type(
                fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr
            ),
            struct=True,
        )

        # typedef T paramType;
        pt.addstmt(Typedef(fortype, "paramType"))

        # static void Write(IPC::MessageWriter*, const paramType&);
        # (or `paramType&&` when |needsmove|)
        if needsmove:
            intype = Type("paramType", rvalref=True)
        else:
            intype = Type("paramType", ref=True, const=True)
        writemthd = MethodDefn(
            MethodDecl(
                "Write",
                params=[
                    Decl(Type("IPC::MessageWriter", ptr=True), cls.writervar.name),
                    Decl(intype, cls.var.name),
                ],
                methodspec=MethodSpec.STATIC,
            )
        )
        writemthd.addstmts(write)
        pt.addstmt(writemthd)

        # static ReadResult<T> Read(MessageReader*);
        readmthd = MethodDefn(
            MethodDecl(
                "Read",
                params=[
                    Decl(Type("IPC::MessageReader", ptr=True), cls.readervar.name),
                ],
                ret=Type("IPC::ReadResult<paramType>"),
                methodspec=MethodSpec.STATIC,
            )
        )
        readmthd.addstmts(read)
        pt.addstmt(readmthd)

        # Split the class into declaration and definition
        clsdecl, methoddefns = _splitClassDeclDefn(pt)

        namespaces = [Namespace("IPC")]
        clsns = _putInNamespaces(clsdecl, namespaces)
        defns = _putInNamespaces(methoddefns, namespaces)
        return clsns, defns

    @classmethod
    def actorPickling(cls, actortype, side):
        """Generates pickling for IPDL actors. This is a |nullable| deserializer.
        Write and read callers will perform nullability validation.

        Write: serializes the actor's Id() as an int32_t, using 0 for a null
        actor.  Release-asserts that the writer has an actor, that the
        actor being sent uses the same IPC channel, and that it can still
        send.  An id of 1 triggers FatalError("Actor has been |delete|d").

        Read: calls ReadActor() on the reader's actor with this actor type's
        name and protocol id.  If an actor is found it is static_cast to the
        side-specific actor pointer type; otherwise `{}` is returned.
        """

        cxxtype = _cxxBareType(actortype, side, fq=True)

        write = StmtCode(
            """
            MOZ_RELEASE_ASSERT(
                ${writervar}->GetActor(),
                "Cannot serialize managed actors without an actor");

            int32_t id;
            if (!${var}) {
                id = 0;  // kNullActorId
            } else {
                id = ${var}->Id();
                if (id == 1) {  // kFreedActorId
                    ${var}->FatalError("Actor has been |delete|d");
                }
                MOZ_RELEASE_ASSERT(
                    ${writervar}->GetActor()->GetIPCChannel() == ${var}->GetIPCChannel(),
                    "Actor must be from the same channel as the"
                    " actor it's being sent over");
                MOZ_RELEASE_ASSERT(
                    ${var}->CanSend(),
                    "Actor must still be open when sending");
            }

            ${write};
            """,
            var=cls.var,
            writervar=cls.writervar,
            write=cls.write(ExprVar("id"), cls.writervar),
        )

        # IPC::ReadResult<paramType> Read(IPC::MessageReader*) impl
        read = StmtCode(
            """
            MOZ_RELEASE_ASSERT(
                ${readervar}->GetActor(),
                "Cannot deserialize managed actors without an actor");
            mozilla::Maybe<mozilla::ipc::IProtocol*> actor = ${readervar}->GetActor()
              ->ReadActor(${readervar}, true, ${actortype}, ${protocolid});
            if (actor.isSome()) {
                return static_cast<${cxxtype}>(actor.ref());
            }
            return {};
            """,
            readervar=cls.readervar,
            actortype=ExprLiteral.String(actortype.name()),
            protocolid=_protocolId(actortype),
            cxxtype=cxxtype,
        )

        return cls.generateDecl(cxxtype, [write], [read])

    @classmethod
    def structPickling(cls, structtype):
        """Build ParamTraits for IPDL struct type |structtype|.

        Fields that are not POD-sized (pod_size == pod_size_sentinel) are
        written and read one at a time, in IPDL order, each with its own
        sentinel.  The values read become constructor arguments for the
        result, moved when _cxxTypeCanMove allows.  Each POD field instead
        gets a zero-initialized placeholder argument.  After construction,
        POD fields are bulk-copied in member order, in runs of consecutive
        fields with equal pod_size, into the ReadResult.  The write side
        mirrors this order.  Write takes its argument by rvalue reference
        when _cxxTypeNeedsMoveForSend(structtype) is true.
        """
        sd = structtype._ast
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(structtype.fullname())

        write = []
        read = []

        # First serialize/deserialize all non-pod data in IPDL order. These need
        # to be read/written first because they'll be used to invoke the IPDL
        # struct's constructor.
        ctorargs = []
        for f in sd.fields_ipdl_order():
            if pod_size(f.ipdltype) == pod_size_sentinel:
                write.append(
                    cls.checkedWrite(
                        f.ipdltype,
                        ExprCall(f.getMethod(thisexpr=cls.var, sel=".")),
                        cls.writervar,
                        sentinelKey=f.basename,
                    )
                )
                # The last argument is the human-readable field description
                # used in deserialization error messages.
                read.append(
                    cls._checkedRead(
                        f.ipdltype,
                        f.bareType(fq=True),
                        f.argVar(),
                        f.basename,
                        "'"
                        + f.getMethod().name
                        + "' "
                        + "("
                        + f.ipdltype.name()
                        + ") member of "
                        + "'"
                        + structtype.name()
                        + "'",
                    )
                )
                if _cxxTypeCanMove(f.ipdltype):
                    ctorargs.append(ExprMove(f.argVar()))
                else:
                    ctorargs.append(f.argVar())
            else:
                # We're going to bulk-read in this value later, so we'll just
                # zero-initialize it for now.
                ctorargs.append(ExprCode("${type}{0}", type=f.bareType(fq=True)))

        resultvar = ExprVar("result__")
        read.append(
            StmtDecl(
                Decl(_cxxReadResultType(Type("paramType")), resultvar.name),
                initargs=[ExprVar("std::in_place")] + ctorargs,
            )
        )

        # After non-pod data, bulk read/write pod data in member order. This has
        # to be done after the result has been constructed, so that we have
        # somewhere to read into.
        for (size, fields) in itertools.groupby(
            sd.fields_member_order(), lambda f: pod_size(f.ipdltype)
        ):
            if size != pod_size_sentinel:
                fields = list(fields)
                write.append(cls.checkedBulkWrite(cls.var, size, fields))
                read.append(cls.checkedBulkRead(resultvar, size, fields))

        read.append(StmtReturn(resultvar))

        return cls.generateDecl(
            cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(structtype)
        )

    @classmethod
    def unionPickling(cls, uniontype):
        """Build ParamTraits for IPDL union type |uniontype|.

        Write sends the active variant tag (`int type = aVar.type()`-style
        call via ud.callType) with a sentinel keyed on the union name.  It
        then switches on the tag and writes the variant value with a
        sentinel keyed on that variant's enum name.  Read mirrors this and
        returns the moved variant value.  An unknown tag makes the
        generated code call FatalError.  Write then returns; Read returns
        readResultError().
        """
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(uniontype.fullname())
        ud = uniontype._ast

        # Use typedef to set up an alias so it's easier to reference the union type.
        alias = "union__"
        typevar = ExprVar("type")

        prelude = [
            Typedef(cxxtype, alias),
        ]

        writeswitch = StmtSwitch(typevar)
        write = prelude + [
            StmtDecl(Decl(Type.INT, typevar.name), init=ud.callType(cls.var)),
            cls.checkedWrite(
                None, typevar, cls.writervar, sentinelKey=uniontype.name()
            ),
            Whitespace.NL,
            writeswitch,
        ]

        readswitch = StmtSwitch(typevar)
        read = prelude + [
            cls._checkedRead(
                None,
                Type.INT,
                typevar,
                uniontype.name(),
                "type of union " + uniontype.name(),
            ),
            Whitespace.NL,
            readswitch,
        ]

        # One case per variant in both the write and the read switch; each
        # variant's sentinel is keyed on its enum name.
        for c in ud.components:
            caselabel = CaseLabel(alias + "::" + c.enum())
            origenum = c.enum()

            writecase = StmtBlock()
            wstmt = cls.checkedWrite(
                c.ipdltype,
                ExprCall(ExprSelect(cls.var, ".", c.getTypeName())),
                cls.writervar,
                sentinelKey=c.enum(),
            )
            writecase.addstmts([wstmt, StmtReturn()])
            writeswitch.addcase(caselabel, writecase)

            readcase = StmtBlock()
            tmpvar = ExprVar("tmp")
            readcase.addstmts(
                [
                    cls._checkedRead(
                        c.ipdltype,
                        c.bareType(fq=True),
                        tmpvar,
                        origenum,
                        "variant " + origenum + " of union " + uniontype.name(),
                    ),
                    StmtReturn(ExprMove(tmpvar)),
                ]
            )
            readswitch.addcase(caselabel, readcase)

        # Add the error default case
        writeswitch.addcase(
            DefaultLabel(),
            StmtBlock(
                [
                    cls.fatalError(
                        cls.writervar, "unknown variant of union " + uniontype.name()
                    ),
                    StmtReturn(),
                ]
            ),
        )
        readswitch.addcase(
            DefaultLabel(),
            StmtBlock(
                [
                    cls.fatalError(
                        cls.readervar, "unknown variant of union " + uniontype.name()
                    ),
                    StmtReturn(readResultError()),
                ]
            ),
        )

        return cls.generateDecl(
            cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(uniontype)
        )


# --------------------------------------------------


class _ComputeTypeDeps(TypeVisitor):
    """Pass that gathers the C++ types that a particular IPDL type
    (recursively) depends on.  There are three kinds of dependencies: (i)
    types that need forward declaration; (ii) types that need a |using|
    stmt; (iii) IPDL structs or unions which must be fully declared
    before this struct.  Some types generate multiple kinds.

    Results accumulate on the instance:
      usingTypedefs    -- Typedef nodes for types whose fully-qualified name
                          differs from the short name used in generated code.
      forwardDeclStmts -- forward declarations (actors, and structs/unions
                          that are mutually recursive with a union |fortype|).
      fullDeclTypes    -- IPDL struct/union types needing a full declaration.
      includeHeaders   -- header paths to #include.  Entries come from
                          |typesToIncludes| (fq name -> header) and from
                          actor types.
    """

    def __init__(self, fortype, typesToIncludes=None):
        # NOTE(review): self.visited is used below but not set here;
        # presumably TypeVisitor.__init__ initializes it.
        TypeVisitor.__init__(self)
        self.usingTypedefs = []
        self.forwardDeclStmts = []
        self.fullDeclTypes = []
        self.includeHeaders = set()
        self.fortype = fortype
        self.typesToIncludes = typesToIncludes

    def maybeTypedef(self, fqname, name, templateargs=None):
        """Record a |using| typedef from |fqname| to |name| if they differ.

        Also records the header for |fqname| when it appears in
        |typesToIncludes|.  |fqname| must start with "::".  |templateargs|
        defaults to a fresh empty list on every call, so Typedef nodes
        never share one mutable default.
        """
        assert fqname.startswith("::")
        if templateargs is None:
            templateargs = []
        if fqname != name:
            self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs))
        if self.typesToIncludes is not None and fqname in self.typesToIncludes:
            self.includeHeaders.add(self.typesToIncludes[fqname])

    def visitImportedCxxType(self, t):
        if t in self.visited:
            return
        self.visited.add(t)
        self.maybeTypedef(t.fullname(), t.name())

    def visitActorType(self, t):
        """Typedef both the Parent and Child actor classes, forward-declare
        both, and require mozilla/ipc/SideVariant.h."""
        if t in self.visited:
            return
        self.visited.add(t)

        fqname, name = t.fullname(), t.name()

        self.includeHeaders.add("mozilla/ipc/SideVariant.h")
        self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent"))
        self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child"))

        self.forwardDeclStmts.extend(
            [
                _makeForwardDeclForActor(t.protocol, "parent"),
                Whitespace.NL,
                _makeForwardDeclForActor(t.protocol, "child"),
                Whitespace.NL,
            ]
        )

    def visitStructOrUnionType(self, su, defaultVisit):
        """Shared handling for IPDL structs and unions.

        Skips |fortype| itself.  Records a typedef, then a forward
        declaration or a full-declaration dependency, and finally recurses
        into |su|'s members via |defaultVisit|.
        """
        if su in self.visited or su == self.fortype:
            return
        self.visited.add(su)
        self.maybeTypedef(su.fullname(), su.name())

        # Mutually recursive fields in unions are behind indirection, so we only
        # need a forward decl, and don't need a full type declaration.
        if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith(
            su
        ):
            self.forwardDeclStmts.append(_makeForwardDecl(su))
        else:
            self.fullDeclTypes.append(su)

        return defaultVisit(self, su)

    def visitStructType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitStructType)

    def visitUnionType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType)

    def visitArrayType(self, t):
        return TypeVisitor.visitArrayType(self, t)

    def visitMaybeType(self, m):
        return TypeVisitor.visitMaybeType(self, m)

    def visitShmemType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::Shmem", "Shmem")

    def visitByteBufType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::ByteBuf", "ByteBuf")

    def visitFDType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::FileDescriptor", "FileDescriptor")

    def visitEndpointType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::Endpoint", "Endpoint", ["FooSide"])
        self.visitActorType(s.actor)

    def visitManagedEndpointType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef(
            "::mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"]
        )
        self.visitActorType(s.actor)

    def visitUniquePtrType(self, s):
        # NOTE(review): unlike visitArrayType/visitMaybeType this does not
        # recurse into the pointee type, so its dependencies are not
        # collected here. Confirm this is intentional.
        if s in self.visited:
            return
        self.visited.add(s)

    # The following types are never expected to be reached by this pass.
    def visitVoidType(self, v):
        assert 0

    def visitMessageType(self, v):
        assert 0

    def visitProtocolType(self, v):
        assert 0


def _fieldStaticAssertions(sd):
    """Return static_assert statements that verify the packed layout of |sd|.

    Walks the struct's fields in member order and groups consecutive fields
    sharing the same pod_size().  Non-POD runs (pod_size_sentinel) and runs
    of a single field are skipped.  For each remaining run, the emitted
    assertion requires the byte distance between the first and last member
    to equal size * (count - 1), i.e. that the run is contiguous.  These
    are the runs _ParamTraits bulk-copies as raw bytes.
    """
    layout_check = """
            static_assert(
                (offsetof(${struct}, ${last}) - offsetof(${struct}, ${first})) == ${expected},
                "Bad assumptions about field layout!");
            """

    assertions = []
    runs = itertools.groupby(
        sd.fields_member_order(), lambda fld: pod_size(fld.ipdltype)
    )
    for podsize, run in runs:
        if podsize == pod_size_sentinel:
            continue

        members = list(run)
        count = len(members)
        if count == 1:
            # A single field trivially satisfies the layout assumption.
            continue

        head = members[0].memberVar()
        tail = members[-1].memberVar()
        assertions.append(
            StmtCode(
                layout_check,
                struct=sd.name,
                first=head,
                last=tail,
                expected=ExprLiteral.Int(podsize * (count - 1)),
            )
        )

    return assertions


def _generateCxxStruct(sd):
    """Lower the IPDL struct declaration |sd| to a C++ class.

    Returns (forwarddeclstmts, fulldecltypes, struct):
      forwarddeclstmts -- forward declarations the class needs, gathered by
                          _ComputeTypeDeps over the struct's fields.
      fulldecltypes    -- IPDL struct/union types that must be fully
                          declared before this class.
      struct           -- the generated (final) Class node.

    The generated class contains:
      - private |using| typedefs for field types;
      - a defaulted default constructor, with clang's
        -Wdefaulted-function-deleted suppressed;
      - when the struct has fields, a field-wise constructor (parameters in
        IPDL order, member inits in member order), plus an all-move variant
        when it would differ;
      - operator== / operator!= if the struct has the "Comparable" attribute;
      - mutable and const accessors for every field;
      - a private StaticAssertions() method checking packed field layout,
        when there is anything to assert;
      - the data members, typed via _effectiveMemberType.
    """
    # compute all the typedefs and forward decls we need to make
    gettypedeps = _ComputeTypeDeps(sd.decl.type)
    for f in sd.fields:
        f.ipdltype.accept(gettypedeps)

    usingTypedefs = gettypedeps.usingTypedefs
    forwarddeclstmts = gettypedeps.forwardDeclStmts
    fulldecltypes = gettypedeps.fullDeclTypes

    struct = Class(sd.name, final=True)
    struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC])

    constreftype = Type(sd.name, const=True, ref=True)

    # Struct()
    # We want the default constructor to be declared if it is available, but
    # some of our members may not be default-constructible. Silence the
    # warning which clang generates in that case.
    #
    # Members which need value initialization will be handled by wrapping
    # the member in a template type when declaring them.
    struct.addcode(
        """
        #ifdef __clang__
        #  pragma clang diagnostic push
        #  if __has_warning("-Wdefaulted-function-deleted")
        #    pragma clang diagnostic ignored "-Wdefaulted-function-deleted"
        #  endif
        #endif
        ${name}() = default;
        #ifdef __clang__
        #  pragma clang diagnostic pop
        #endif

        """,
        name=sd.name,
    )

    # If this is an empty struct (no fields), then the default ctor
    # and "create-with-fields" ctors are equivalent, so only emit the
    # field-wise ctors when there are fields.
    if sd.fields:
        assert len(sd.fields) == len(sd.packed_field_order)

        # Struct(const field1& _f1, ...)
        # Types that must be moved are taken by rvalue reference even here.
        valctor = ConstructorDefn(
            ConstructorDecl(
                sd.name,
                params=[
                    Decl(
                        f.forceMoveType()
                        if _cxxTypeNeedsMoveForData(f.ipdltype)
                        else f.constRefType(),
                        f.argVar().name,
                    )
                    for f in sd.fields_ipdl_order()
                ],
                force_inline=True,
            )
        )
        valctor.memberinits = []
        for f in sd.fields_member_order():
            arg = f.argVar()
            if _cxxTypeNeedsMoveForData(f.ipdltype):
                arg = ExprMove(arg)
            valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg]))

        struct.addstmts([valctor, Whitespace.NL])

        # If a constructor which moves each argument would be different from the
        # `const T&` version, also generate that constructor.
        if not all(
            _cxxTypeNeedsMoveForData(f.ipdltype) or not _cxxTypeCanMove(f.ipdltype)
            for f in sd.fields_ipdl_order()
        ):
            # Struct(field1&& _f1, ...)
            valmovector = ConstructorDefn(
                ConstructorDecl(
                    sd.name,
                    params=[
                        Decl(
                            f.forceMoveType()
                            if _cxxTypeCanMove(f.ipdltype)
                            else f.constRefType(),
                            f.argVar().name,
                        )
                        for f in sd.fields_ipdl_order()
                    ],
                    force_inline=True,
                )
            )

            valmovector.memberinits = []
            for f in sd.fields_member_order():
                arg = f.argVar()
                if _cxxTypeCanMove(f.ipdltype):
                    arg = ExprMove(arg)
                valmovector.memberinits.append(
                    ExprMemberInit(f.memberVar(), args=[arg])
                )

            struct.addstmts([valmovector, Whitespace.NL])

    # The default copy and move constructors, the copy and move assignment
    # operators, and the default destructor will do the right thing.

    if "Comparable" in sd.attributes:
        # bool operator==(const Struct& _o)
        # Compares fields one by one in IPDL order, returning false at the
        # first mismatch.
        ovar = ExprVar("_o")
        opeqeq = MethodDefn(
            MethodDecl(
                "operator==",
                params=[Decl(constreftype, ovar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        for f in sd.fields_ipdl_order():
            ifneq = StmtIf(
                ExprNot(
                    ExprBinary(
                        ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar))
                    )
                )
            )
            ifneq.addifstmt(StmtReturn.FALSE)
            opeqeq.addstmt(ifneq)
        opeqeq.addstmt(StmtReturn.TRUE)
        struct.addstmts([opeqeq, Whitespace.NL])

        # bool operator!=(const Struct& _o)
        opneq = MethodDefn(
            MethodDecl(
                "operator!=",
                params=[Decl(constreftype, ovar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar]))))
        struct.addstmts([opneq, Whitespace.NL])

    # field1& f1()
    # const field1& f1() const
    for f in sd.fields_ipdl_order():
        get = MethodDefn(
            MethodDecl(
                f.getMethod().name, params=[], ret=f.refType(), force_inline=True
            )
        )
        get.addstmt(StmtReturn(f.refExpr()))

        getconstdecl = deepcopy(get.decl)
        getconstdecl.ret = f.constRefType()
        getconstdecl.const = True
        getconst = MethodDefn(getconstdecl)
        getconst.addstmt(StmtReturn(f.constRefExpr()))

        struct.addstmts([get, getconst, Whitespace.NL])

    # private:
    struct.addstmt(Label.PRIVATE)

    # Static assertions to ensure our assumptions about field layout match
    # what the compiler is actually producing.  We define this as a member
    # function, rather than throwing the assertions in the constructor or
    # similar, because we don't want to evaluate the static assertions every
    # time the header file containing the structure is included.
    staticasserts = _fieldStaticAssertions(sd)
    if staticasserts:
        method = MethodDefn(
            MethodDecl("StaticAssertions", params=[], ret=Type.VOID, const=True)
        )
        method.addstmts(staticasserts)
        struct.addstmts([method])

    # members
    struct.addstmts(
        [
            StmtDecl(Decl(_effectiveMemberType(f), f.memberVar().name))
            for f in sd.fields_member_order()
        ]
    )

    return forwarddeclstmts, fulldecltypes, struct


def _effectiveMemberType(f):
    """Return the C++ type used to declare struct field |f| as a data member.

    The field's bare type is wrapped in ::mozilla::ipc::IPDLStructMember<>.
    An nsTArray<T> field is stored as CopyableTArray<T> instead.  Structs
    must stay copyable for backwards compatibility, and CopyableTArray<T>
    derives from nsTArray<T>, so method signatures can keep exposing
    nsTArray<T>.
    """
    member_type = f.bareType()
    is_plain_array = member_type.name == "nsTArray"
    if is_plain_array:
        member_type.name = "CopyableTArray"
    wrapped = Type("::mozilla::ipc::IPDLStructMember", T=[member_type])
    return wrapped


# --------------------------------------------------


def _generateCxxUnion(ud):
    """Generate the C++ class for the IPDL union declaration |ud|.

    Returns a tuple |(forwarddeclstmts, fulldecltypes, cls)|, where the first
    two entries are the forward-declaration statements and fully-declared
    types collected by _ComputeTypeDeps over the union's components, and
    |cls| is the generated |Class| node.
    """
    # This Union class basically consists of a type (enum) and a
    # union for storage.  The union can contain POD and non-POD
    # types.  Each type needs a copy/move ctor, assignment operators,
    # and dtor.
    #
    # Rather than templating this class and only providing
    # specializations for the types we support, which is slightly
    # "unsafe" in that C++ code can add additional specializations
    # without the IPDL compiler's knowledge, we instead explicitly
    # implement non-templated methods for each supported type.
    #
    # The one complication that arises is that C++, for arcane
    # reasons, does not allow the placement destructor of a
    # builtin type, like int, to be directly invoked.  So we need
    # to hack around this by internally typedef'ing all
    # constituent types.  Sigh.
    #
    # So, for each type, this "Union" class needs:
    # (private)
    #  - entry in the type enum
    #  - entry in the storage union
    #  - [type]ptr() method to get a type* from the underlying union
    #  - same as above to get a const type*
    #  - typedef to hack around placement delete limitations
    # (public)
    #  - placement delete case for dtor
    #  - copy ctor
    #  - move ctor
    #  - case in generic copy ctor
    #  - copy operator= impl
    #  - move operator= impl
    #  - case in generic operator=
    #  - operator [type&]
    #  - operator [const type&] const
    #  - [type&] get_[type]()
    #  - [const type&] get_[type]() const
    #
    # Copy constructors / copy assignment operators are only emitted for
    # component types (and, for the whole-union versions, only for unions)
    # that do not require move semantics; see the _cxxTypeNeedsMoveForData()
    # checks below.
    #
    cls = Class(ud.name, final=True)
    # const Union&, i.e., Union type with inparam semantics
    inClsType = Type(ud.name, const=True, ref=True)
    refClsType = Type(ud.name, ref=True)
    rvalueRefClsType = Type(ud.name, rvalref=True)
    typetype = Type("Type")
    valuetype = Type("Value")
    mtypevar = ExprVar("mType")
    mvaluevar = ExprVar("mValue")
    maybedtorvar = ExprVar("MaybeDestroy")
    assertsanityvar = ExprVar("AssertSanity")
    tnonevar = ExprVar("T__None")
    tlastvar = ExprVar("T__Last")

    # Build a call expression |[uvar.]AssertSanity([expectTypeVar])|.
    def callAssertSanity(uvar=None, expectTypeVar=None):
        func = assertsanityvar
        args = []
        if uvar is not None:
            func = ExprSelect(uvar, ".", assertsanityvar.name)
        if expectTypeVar is not None:
            args.append(expectTypeVar)
        return ExprCall(func, args=args)

    # Build the statement |MaybeDestroy();| on this object.
    def maybeDestroy():
        return StmtExpr(ExprCall(maybedtorvar))

    # compute all the typedefs and forward decls we need to make
    gettypedeps = _ComputeTypeDeps(ud.decl.type)
    for c in ud.components:
        c.ipdltype.accept(gettypedeps)

    usingTypedefs = gettypedeps.usingTypedefs
    forwarddeclstmts = gettypedeps.forwardDeclStmts
    fulldecltypes = gettypedeps.fullDeclTypes

    # the |Type| enum, used to switch on the discunion's real type
    cls.addstmt(Label.PUBLIC)
    typeenum = TypeEnum(typetype.name)
    typeenum.addId(tnonevar.name, 0)
    firstid = ud.components[0].enum()
    typeenum.addId(firstid, 1)
    for c in ud.components[1:]:
        typeenum.addId(c.enum())
    # T__Last aliases the enumerator of the last component; AssertSanity()
    # uses it as the inclusive upper bound of valid tags.
    typeenum.addId(tlastvar.name, ud.components[-1].enum())
    cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL])

    cls.addstmt(Label.PRIVATE)
    cls.addstmts(
        usingTypedefs
        # hacky typedef's that allow placement dtors of builtins
        + [Typedef(c.internalType(), c.typedef()) for c in ud.components]
    )
    cls.addstmt(Whitespace.NL)

    # the C++ union the discunion uses for storage
    valueunion = TypeUnion(valuetype.name)
    for c in ud.components:
        valueunion.addComponent(c.unionType(), c.name)
    cls.addstmts([StmtDecl(Decl(valueunion, "")), Whitespace.NL])

    # for each constituent type T, add private accessors that
    # return a pointer to the Value union storage casted to |T*|
    # and |const T*|
    for c in ud.components:
        getptr = MethodDefn(
            MethodDecl(
                c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True
            )
        )
        getptr.addstmt(StmtReturn(c.ptrToSelfExpr()))

        getptrconst = MethodDefn(
            MethodDecl(
                c.getConstPtrName(),
                params=[],
                ret=c.constPtrToType(),
                const=True,
                force_inline=True,
            )
        )
        getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr()))

        cls.addstmts([getptr, getptrconst])
    cls.addstmt(Whitespace.NL)

    # add a helper method |void MaybeDestroy()| that invokes the placement
    # dtor on the current underlying value, if there is one.  It does NOT
    # reset mType; callers must assign a new tag (or T__None) afterwards.
    maybedtor = MethodDefn(MethodDecl(maybedtorvar.name, ret=Type.VOID))
    # nothing is stored when the tag is T__None, so there is nothing to destroy
    ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar))
    ifnone.addifstmt(StmtReturn())
    # need to destroy.  switch on underlying type
    dtorswitch = StmtSwitch(mtypevar)
    for c in ud.components:
        dtorswitch.addcase(
            CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()])
        )
    dtorswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()])
    )
    maybedtor.addstmts([ifnone, dtorswitch])
    cls.addstmts([maybedtor, Whitespace.NL])

    # add helper methods that ensure the discunion has a
    # valid type:
    #   AssertSanity() const        -- T__None <= mType <= T__Last
    #   AssertSanity(aType) const   -- the above, plus mType == aType
    sanity = MethodDefn(
        MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True)
    )
    sanity.addstmts(
        [
            _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"),
            _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"),
        ]
    )
    cls.addstmt(sanity)

    atypevar = ExprVar("aType")
    sanity2 = MethodDefn(
        MethodDecl(
            assertsanityvar.name,
            params=[Decl(typetype, atypevar.name)],
            ret=Type.VOID,
            const=True,
            force_inline=True,
        )
    )
    sanity2.addstmts(
        [
            StmtExpr(ExprCall(assertsanityvar)),
            _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"),
        ]
    )
    cls.addstmts([sanity2, Whitespace.NL])

    # ---- begin public methods -----

    # Union() default ctor
    cls.addstmts(
        [
            Label.PUBLIC,
            ConstructorDefn(
                ConstructorDecl(ud.name, force_inline=True),
                memberinits=[ExprMemberInit(mtypevar, [tnonevar])],
            ),
            Whitespace.NL,
        ]
    )

    # Union(const T&) copy & Union(T&&) move ctors
    othervar = ExprVar("aOther")
    for c in ud.components:
        # copy ctor from T, only for component types that are copyable
        if not _cxxTypeNeedsMoveForData(c.ipdltype):
            copyctor = ConstructorDefn(
                ConstructorDecl(ud.name, params=[Decl(c.constRefType(), othervar.name)])
            )
            copyctor.addstmts(
                [
                    StmtExpr(c.callCtor(othervar)),
                    StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                ]
            )
            cls.addstmts([copyctor, Whitespace.NL])

        # move ctor from T, only for component types that can be moved
        if not _cxxTypeCanMove(c.ipdltype):
            continue
        movector = ConstructorDefn(
            ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)])
        )
        movector.addstmts(
            [
                StmtExpr(c.callCtor(ExprMove(othervar))),
                StmtExpr(ExprAssn(mtypevar, c.enumvar())),
            ]
        )
        cls.addstmts([movector, Whitespace.NL])

    # True if any component requires move semantics; in that case the
    # whole-union copy ctor and copy assignment are not generated.
    unionNeedsMove = any(_cxxTypeNeedsMoveForData(c.ipdltype) for c in ud.components)

    # Union(const Union&) copy ctor
    if not unionNeedsMove:
        copyctor = ConstructorDefn(
            ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)])
        )
        othertype = ud.callType(othervar)
        copyswitch = StmtSwitch(othertype)
        for c in ud.components:
            copyswitch.addcase(
                CaseLabel(c.enum()),
                StmtBlock(
                    [
                        StmtExpr(
                            c.callCtor(
                                ExprCall(
                                    ExprSelect(othervar, ".", c.getConstTypeName())
                                )
                            )
                        ),
                        StmtBreak(),
                    ]
                ),
            )
        copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
        copyswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
        )
        copyctor.addstmts(
            [
                StmtExpr(callAssertSanity(uvar=othervar)),
                copyswitch,
                StmtExpr(ExprAssn(mtypevar, othertype)),
            ]
        )
        cls.addstmts([copyctor, Whitespace.NL])

    # Union(Union&&) move ctor
    movector = ConstructorDefn(
        ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)])
    )
    othertypevar = ExprVar("t")
    moveswitch = StmtSwitch(othertypevar)
    for c in ud.components:
        case = StmtBlock()
        if c.recursive:
            # This is sound as we set aOther.mType to T__None after the
            # switch. The pointer in the union will be left dangling.
            case.addstmts(
                [
                    # ptr_C() = other.ptr_C()
                    StmtExpr(
                        ExprAssn(
                            c.callGetPtr(),
                            ExprCall(
                                ExprSelect(othervar, ".", ExprVar(c.getPtrName()))
                            ),
                        )
                    )
                ]
            )
        else:
            case.addstmts(
                [
                    # new ... (Move(other.get_C()))
                    StmtExpr(
                        c.callCtor(
                            ExprMove(
                                ExprCall(ExprSelect(othervar, ".", c.getTypeName()))
                            )
                        )
                    ),
                    # other.MaybeDestroy()
                    StmtExpr(ExprCall(ExprSelect(othervar, ".", maybedtorvar))),
                ]
            )
        case.addstmts([StmtBreak()])
        moveswitch.addcase(CaseLabel(c.enum()), case)
    moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
    moveswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
    )
    # Generated body: capture aOther's tag in |t|, move the active member,
    # mark aOther empty, then adopt |t| as our own tag.
    movector.addstmts(
        [
            StmtExpr(callAssertSanity(uvar=othervar)),
            StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)),
            moveswitch,
            StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)),
            StmtExpr(ExprAssn(mtypevar, othertypevar)),
        ]
    )
    cls.addstmts([movector, Whitespace.NL])

    # ~Union()
    dtor = DestructorDefn(DestructorDecl(ud.name))
    dtor.addstmt(maybeDestroy())
    cls.addstmts([dtor, Whitespace.NL])

    # type()
    typemeth = MethodDefn(
        MethodDecl("type", ret=typetype, const=True, force_inline=True)
    )
    typemeth.addstmt(StmtReturn(mtypevar))
    cls.addstmts([typemeth, Whitespace.NL])

    # Union& operator= methods
    rhsvar = ExprVar("aRhs")
    for c in ud.components:

        # Body shared by operator=(const T&) and operator=(T&&): destroy the
        # current value, construct from |rhs|, set the tag, return *this.
        # |c| is captured by this closure, but opeqBody is only called within
        # the current loop iteration, so late binding is not a problem.
        def opeqBody(rhs):
            return [
                # might need to placement-delete old value first
                maybeDestroy(),
                StmtExpr(c.callCtor(rhs)),
                StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                StmtReturn(ExprDeref(ExprVar.THIS)),
            ]

        if not _cxxTypeNeedsMoveForData(c.ipdltype):
            # Union& operator=(const T&)
            opeq = MethodDefn(
                MethodDecl(
                    "operator=",
                    params=[Decl(c.constRefType(), rhsvar.name)],
                    ret=refClsType,
                )
            )
            opeq.addstmts(opeqBody(rhsvar))
            cls.addstmts([opeq, Whitespace.NL])

        # Union& operator=(T&&)
        if not _cxxTypeCanMove(c.ipdltype):
            continue

        opeq = MethodDefn(
            MethodDecl(
                "operator=",
                params=[Decl(c.forceMoveType(), rhsvar.name)],
                ret=refClsType,
            )
        )
        opeq.addstmts(opeqBody(ExprMove(rhsvar)))
        cls.addstmts([opeq, Whitespace.NL])

    # Union& operator=(const Union&)
    if not unionNeedsMove:
        opeq = MethodDefn(
            MethodDecl(
                "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType
            )
        )
        rhstypevar = ExprVar("t")
        opeqswitch = StmtSwitch(rhstypevar)
        for c in ud.components:
            case = StmtBlock()
            case.addstmts(
                [
                    maybeDestroy(),
                    StmtExpr(
                        c.callCtor(
                            ExprCall(ExprSelect(rhsvar, ".", c.getConstTypeName()))
                        )
                    ),
                    StmtBreak(),
                ]
            )
            opeqswitch.addcase(CaseLabel(c.enum()), case)
        opeqswitch.addcase(
            CaseLabel(tnonevar.name),
            StmtBlock([maybeDestroy(), StmtBreak()]),
        )
        opeqswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
        )
        opeq.addstmts(
            [
                StmtExpr(callAssertSanity(uvar=rhsvar)),
                StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
                opeqswitch,
                StmtExpr(ExprAssn(mtypevar, rhstypevar)),
                StmtReturn(ExprDeref(ExprVar.THIS)),
            ]
        )
        cls.addstmts([opeq, Whitespace.NL])

    # Union& operator=(Union&&)
    opeq = MethodDefn(
        MethodDecl(
            "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType
        )
    )
    rhstypevar = ExprVar("t")
    opeqswitch = StmtSwitch(rhstypevar)
    for c in ud.components:
        case = StmtBlock()
        if c.recursive:
            # Steal aRhs's pointer; aRhs.mType is set to T__None after the
            # switch so aRhs will not destroy it.
            case.addstmts(
                [
                    maybeDestroy(),
                    StmtExpr(
                        ExprAssn(
                            c.callGetPtr(),
                            ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))),
                        )
                    ),
                ]
            )
        else:
            case.addstmts(
                [
                    maybeDestroy(),
                    StmtExpr(
                        c.callCtor(
                            ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())))
                        )
                    ),
                    # other.MaybeDestroy()
                    StmtExpr(ExprCall(ExprSelect(rhsvar, ".", maybedtorvar))),
                ]
            )
        case.addstmts([StmtBreak()])
        opeqswitch.addcase(CaseLabel(c.enum()), case)
    opeqswitch.addcase(
        CaseLabel(tnonevar.name),
        StmtBlock([maybeDestroy(), StmtBreak()]),
    )
    opeqswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
    )
    opeq.addstmts(
        [
            StmtExpr(callAssertSanity(uvar=rhsvar)),
            StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
            opeqswitch,
            StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)),
            StmtExpr(ExprAssn(mtypevar, rhstypevar)),
            StmtReturn(ExprDeref(ExprVar.THIS)),
        ]
    )
    cls.addstmts([opeq, Whitespace.NL])

    # Equality operators are only generated for unions with the
    # "Comparable" attribute.
    if "Comparable" in ud.attributes:
        # bool operator==(const T&)
        for c in ud.components:
            opeqeq = MethodDefn(
                MethodDecl(
                    "operator==",
                    params=[Decl(c.constRefType(), rhsvar.name)],
                    ret=Type.BOOL,
                    const=True,
                )
            )
            opeqeq.addstmt(
                StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar))
            )
            cls.addstmts([opeqeq, Whitespace.NL])

        # bool operator==(const Union&)
        opeqeq = MethodDefn(
            MethodDecl(
                "operator==",
                params=[Decl(inClsType, rhsvar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar)))
        iftypesmismatch.addifstmt(StmtReturn.FALSE)
        opeqeq.addstmts([iftypesmismatch, Whitespace.NL])

        opeqeqswitch = StmtSwitch(ud.callType())
        for c in ud.components:
            case = StmtBlock()
            case.addstmt(
                StmtReturn(
                    ExprBinary(
                        ExprCall(ExprVar(c.getTypeName())),
                        "==",
                        ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())),
                    )
                )
            )
            opeqeqswitch.addcase(CaseLabel(c.enum()), case)
        opeqeqswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE])
        )
        opeqeq.addstmt(opeqeqswitch)

        cls.addstmts([opeqeq, Whitespace.NL])

    # accessors for each type: operator T&, operator const T&,
    # T& get(), const T& get()
    for c in ud.components:
        getValueVar = ExprVar(c.getTypeName())
        getConstValueVar = ExprVar(c.getConstTypeName())

        getvalue = MethodDefn(
            MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True)
        )
        getvalue.addstmts(
            [
                StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
                StmtReturn(ExprDeref(c.callGetPtr())),
            ]
        )

        getconstvalue = MethodDefn(
            MethodDecl(
                getConstValueVar.name,
                ret=c.constRefType(),
                const=True,
                force_inline=True,
            )
        )
        getconstvalue.addstmts(
            [
                StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
                StmtReturn(c.getConstValue()),
            ]
        )

        cls.addstmts([getvalue, getconstvalue])

        # conversion operators forward to the get_ accessors above
        optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True))
        optype.addstmt(StmtReturn(ExprCall(getValueVar)))
        opconsttype = MethodDefn(
            MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True)
        )
        opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar)))

        cls.addstmts([optype, opconsttype, Whitespace.NL])
    # private vars
    cls.addstmts(
        [
            Label.PRIVATE,
            StmtDecl(Decl(valuetype, mvaluevar.name)),
            StmtDecl(Decl(typetype, mtypevar.name)),
        ]
    )

    return forwarddeclstmts, fulldecltypes, cls


# -----------------------------------------------------------------------------


class _FindFriends(ipdl.ast.Visitor):
    """Collect the protocols that must be |friend|s of a given actor.

    A protocol P is recorded as a friend of |mytype| when one of P's message
    declarations has an in- or out-parameter whose type contains an actor of
    protocol |mytype|.  Every protocol reachable from |mytype|'s toplevel
    protocols through |manages| edges is examined (except |mytype| itself).
    """

    def __init__(self):
        self.mytype = None  # ProtocolType whose friends are being computed
        self.vtype = None  # ProtocolType currently being visited
        self.friends = set()  # set<ProtocolType>

    def findFriends(self, ptype):
        """Return the set of protocol types that should friend |ptype|."""
        self.mytype = ptype
        for root in ptype.toplevels():
            self.walkDownTheProtocolTree(root)
        return self.friends

    # TODO could make this into a _iterProtocolTreeHelper ...
    def walkDownTheProtocolTree(self, ptype):
        """Visit |ptype| (unless it is |mytype|), then recurse into managees."""
        # never |friend| ourself
        if ptype != self.mytype:
            self.visit(ptype)
        for child in ptype.manages:
            if child is ptype:
                # self-managed protocol: don't recurse into ourselves forever
                continue
            self.walkDownTheProtocolTree(child)

    def visit(self, ptype):
        """Run this visitor over |ptype|'s AST with |vtype| set to |ptype|."""
        outer = self.vtype
        self.vtype = ptype
        ptype._ast.accept(self)
        self.vtype = outer

    def visitMessageDecl(self, md):
        # The protocol being visited must friend us if any of this message's
        # parameters carries one of our actors.
        mentionsUs = any(
            actor.protocol == self.mytype for actor in self.iterActorParams(md)
        )
        if mentionsUs:
            self.friends.add(self.vtype)

    def iterActorParams(self, md):
        """Yield every actor type appearing in |md|'s in- then out-params."""
        for decl in itertools.chain(md.inParams, md.outParams):
            yield from ipdl.type.iteractortypes(decl.type)


class _GenerateProtocolActorCode(ipdl.ast.Visitor):
    def __init__(self, myside):
        """Create a generator for the |myside| ("parent" or "child") actor."""
        # Which end of the protocol is generated, plus its title-cased form.
        self.side = myside
        self.prettyside = myside.title()

        # Populated by lower() and the visit* methods.
        self.clsname = None
        self.protocol = None
        self.hdrfile = None
        self.cppfile = None
        self.ns = None
        self.cls = None

        # Include/forward-declaration bookkeeping gathered while visiting.
        self.protocolCxxIncludes = []
        self.actorForwardDecls = []
        self.usingDecls = []
        self.externalIncludes = set()
        self.nonForwardDeclaredHeaders = set()

        # Typedefs every generated actor class starts out with; visiting
        # usings, includes, structs and unions adds more.
        self.typedefSet = {
            Typedef(Type("mozilla::ipc::ActorHandle"), "ActorHandle"),
            Typedef(Type("base::ProcessId"), "ProcessId"),
            Typedef(Type("mozilla::ipc::ProtocolId"), "ProtocolId"),
            Typedef(Type("mozilla::ipc::Endpoint"), "Endpoint", ["FooSide"]),
            Typedef(
                Type("mozilla::ipc::ManagedEndpoint"),
                "ManagedEndpoint",
                ["FooSide"],
            ),
            Typedef(Type("mozilla::UniquePtr"), "UniquePtr", ["T"]),
            Typedef(
                Type("mozilla::ipc::ResponseRejectReason"), "ResponseRejectReason"
            ),
        }

    def lower(self, tu, clsname, cxxHeaderFile, cxxFile):
        """Generate actor class |clsname| for |tu| into the given output Files."""
        self.clsname, self.hdrfile, self.cppfile = clsname, cxxHeaderFile, cxxFile
        tu.accept(self)

    def standardTypedefs(self):
        """Return Typedefs aliasing common IPC types to short names."""
        aliases = [
            ("mozilla::ipc::IProtocol", "IProtocol"),
            ("IPC::Message", "Message"),
            ("base::ProcessHandle", "ProcessHandle"),
            ("mozilla::ipc::MessageChannel", "MessageChannel"),
            ("mozilla::ipc::SharedMemory", "SharedMemory"),
        ]
        return [Typedef(Type(qualified), short) for qualified, short in aliases]

    def visitTranslationUnit(self, tu):
        """Generate the actor header and .cpp for translation unit |tu|.

        Visits |tu|'s includes, usings and struct/union declarations to
        collect typedefs, forward declarations and #includes, then visits the
        protocol, which builds |self.cls|.  That class is split into a
        declaration (written to |self.hdrfile|) and a definition (written to
        |self.cppfile|), together with the actor's ParamTraits pickling code.
        """
        self.protocol = tu.protocol

        hf = self.hdrfile
        cf = self.cppfile

        # make the C++ header
        hf.addthings(
            [_DISCLAIMER]
            + _includeGuardStart(hf)
            + [
                Whitespace.NL,
                CppDirective("include", '"' + _protocolHeaderName(tu.protocol) + '.h"'),
            ]
        )

        for inc in tu.includes:
            inc.accept(self)
        for inc in tu.cxxIncludes:
            inc.accept(self)

        for using in tu.builtinUsing:
            using.accept(self)
        for using in tu.using:
            using.accept(self)
        for su in tu.structsAndUnions:
            su.accept(self)

        # this generates the actor's full impl in self.cls
        tu.protocol.accept(self)

        clsdecl, clsdefn = _splitClassDeclDefn(self.cls)

        # The return type of an out-of-line method definition is not looked
        # up in class scope, so the class's nested |Result| type must be
        # qualified with the class name.
        for stmt in clsdefn.stmts:
            if isinstance(stmt, MethodDefn):
                ret = stmt.decl.ret
                if ret and ret.name == "Result":
                    ret.name = clsdecl.name + "::" + ret.name

        def setToIncludes(s):
            # Sorted so that the generated output is deterministic.
            return [CppDirective("include", '"%s"' % i) for i in sorted(s)]

        def makeNamespace(p, file):
            # Wrap |file| in |p|'s namespaces and return the innermost one
            # (or |file| itself when |p| is not namespaced).
            if not p.namespaces:
                return file
            ns = Namespace(p.namespaces[-1].name)
            outerns = _putInNamespaces(ns, p.namespaces[:-1])
            file.addthing(outerns)
            return ns

        if self.nonForwardDeclaredHeaders:
            hf.addthings(
                [
                    Whitespace("// Headers for things that cannot be forward declared"),
                    Whitespace.NL,
                ]
                + setToIncludes(self.nonForwardDeclaredHeaders)
                + [Whitespace.NL]
            )
        hf.addthings(self.actorForwardDecls)
        hf.addthings(self.usingDecls)

        hdrns = makeNamespace(self.protocol, hf)
        hdrns.addstmts(
            [Whitespace.NL, Whitespace.NL, clsdecl, Whitespace.NL, Whitespace.NL]
        )

        actortype = ActorType(tu.protocol.decl.type)
        traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side)

        hf.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf))

        # If the implementation type is not overridden, add an implicit import
        # for the default implementation header file. Explicit implementation
        # types will specify their headers manually with `include`.
        # E.g. protocol |PFoo| in namespace mozilla::dom, side "parent",
        # yields "mozilla/dom/FooParent.h".
        if self.protocol.implAttribute(self.side) is None:
            assert self.protocol.name.startswith("P")
            self.externalIncludes.add(
                "".join(n.name + "/" for n in self.protocol.namespaces)
                + self.protocol.name[1:]
                + self.side.capitalize()
                + ".h"
            )

        # make the .cpp file
        cf.addthings(
            [
                _DISCLAIMER,
                Whitespace.NL,
                CppDirective(
                    "include",
                    '"' + _protocolHeaderName(self.protocol, self.side) + '.h"',
                ),
            ]
            + setToIncludes(self.externalIncludes)
        )

        cf.addthings(
            [Whitespace.NL]
            + [
                CppDirective("include", '"%s.h"' % inc)
                for inc in self.protocolCxxIncludes
            ]
            + [Whitespace.NL]
            + [
                CppDirective("include", '"%s"' % filename)
                for filename in ipdl.builtin.CppIncludes
            ]
            + [Whitespace.NL]
        )

        cppns = makeNamespace(self.protocol, cf)
        cppns.addstmts(
            [Whitespace.NL, Whitespace.NL, clsdefn, Whitespace.NL, Whitespace.NL]
        )

        cf.addthing(traitsdefn)

    def visitUsingStmt(self, using):
        """Record the typedef, forward decl and/or include for a `using` stmt.

        Types that can be forward declared get a forward declaration in the
        header and their header included from the .cpp.  Other types have
        their header included directly from the actor header.
        """
        decl = using.decl
        if decl.fullname is not None:
            self.typedefSet.add(Typedef(Type(decl.fullname), decl.shortname))

        header = using.header
        if header is None:
            return

        if not using.canBeForwardDeclared():
            self.nonForwardDeclaredHeaders.add(header)
            return

        spec = using.type
        forwardDecl = _makeForwardDeclForQClass(
            spec.baseid,
            spec.quals,
            cls=using.isClass(),
            struct=using.isStruct(),
        )
        self.usingDecls.extend([forwardDecl, Whitespace.NL])
        self.externalIncludes.add(header)

    def visitCxxInclude(self, inc):
        """Remember a C++ header to #include from the generated .cpp file."""
        includedFile = inc.file
        self.externalIncludes.add(includedFile)

    def visitInclude(self, inc):
        """Process an IPDL `include` of an .ipdlh header or another protocol."""
        included = inc.tu
        if included.filetype == "header":
            # Including a header will declare any globals defined by "using"
            # statements into our scope. To serialize these, we also may need
            # cxx include statements, so visit them as well.
            for group in (
                included.cxxIncludes,
                included.using,
                included.structsAndUnions,
            ):
                for node in group:
                    node.accept(self)
            return

        # Includes for protocols only include types explicitly exported by
        # those protocols.
        ip = included.protocol
        if ip == self.protocol:
            return

        mySide = self.side
        theirSide = _otherSide(mySide)

        self.actorForwardDecls.extend(
            [
                _makeForwardDeclForActor(ip.decl.type, mySide),
                _makeForwardDeclForActor(ip.decl.type, theirSide),
                Whitespace.NL,
            ]
        )
        self.protocolCxxIncludes.append(_protocolHeaderName(ip, mySide))

        if ip.decl.fullname is None:
            return

        # Typedef both the same-side and the other-side actor class names.
        for sideName in (mySide.title(), theirSide.title()):
            self.typedefSet.add(
                Typedef(
                    Type(_actorName(ip.decl.fullname, sideName)),
                    _actorName(ip.decl.shortname, sideName),
                )
            )

    def visitStructDecl(self, sd):
        """Typedef the struct's fully-qualified class name to its short name."""
        if sd.decl.fullname is None:
            return
        self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name))

    def visitUnionDecl(self, ud):
        """Typedef the union's fully-qualified class name to its short name."""
        if ud.decl.fullname is None:
            return
        self.typedefSet.add(Typedef(Type(ud.fqClassName()), ud.name))

    def visitProtocol(self, p):
        self.hdrfile.addcode(
            """
            #ifdef DEBUG
            #include "prenv.h"
            #endif  // DEBUG

            #include "mozilla/Tainting.h"
            #include "mozilla/ipc/MessageChannel.h"
            #include "mozilla/ipc/ProtocolUtils.h"
            """
        )

        self.protocol = p
        ptype = p.decl.type
        toplevel = p.decl.type.toplevel()

        hasAsyncReturns = False
        for md in p.messageDecls:
            if md.hasAsyncReturns():
                hasAsyncReturns = True
                break

        inherits = []
        if ptype.isToplevel():
            inherits.append(Inherit(p.openedProtocolInterfaceType(), viz="public"))
        else:
            inherits.append(Inherit(p.managerInterfaceType(), viz="public"))

        if ptype.isToplevel() and self.side == "parent":
            self.hdrfile.addthings(
                [_makeForwardDeclForQClass("nsIFile", []), Whitespace.NL]
            )

        self.cls = Class(self.clsname, inherits=inherits, abstract=True)

        self.cls.addstmt(Label.PRIVATE)
        friends = _FindFriends().findFriends(ptype)
        if ptype.isManaged():
            friends.update(ptype.managers)

        # |friend| managed actors so that they can call our Dealloc*()
        friends.update(ptype.manages)

        # don't friend ourself if we're a self-managed protocol
        friends.discard(ptype)

        for friend in sorted(friends, key=lambda f: f.fullname()):
            self.actorForwardDecls.extend(
                [_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL]
            )
            self.cls.addstmt(
                FriendClassDecl(_actorName(friend.fullname(), self.prettyside))
            )

        self.cls.addstmt(Label.PROTECTED)
        for typedef in sorted(self.typedefSet):
            self.cls.addstmt(typedef)

        self.cls.addstmt(Whitespace.NL)

        if hasAsyncReturns:
            self.cls.addstmt(Label.PUBLIC)
            for md in p.messageDecls:
                if self.sendsMessage(md) and md.hasAsyncReturns():
                    self.cls.addstmt(
                        Typedef(_makePromise(md.returns, self.side), md.promiseName())
                    )
                if self.receivesMessage(md) and md.hasAsyncReturns():
                    self.cls.addstmt(
                        Typedef(_makeResolver(md.returns, self.side), md.resolverName())
                    )
            self.cls.addstmt(Whitespace.NL)

        self.cls.addstmt(Label.PROTECTED)
        # interface methods that the concrete subclass has to impl
        for md in p.messageDecls:
            isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor()

            if self.receivesMessage(md):
                # generate Recv/Answer* interface
                implicit = not isdtor
                returnsems = "resolver" if md.decl.type.isAsync() else "out"
                recvDecl = MethodDecl(
                    md.recvMethod(),
                    params=md.makeCxxParams(
                        paramsems="move",
                        returnsems=returnsems,
                        side=self.side,
                        implicit=implicit,
                        direction="recv",
                    ),
                    ret=Type("mozilla::ipc::IPCResult"),
                    methodspec=MethodSpec.VIRTUAL,
                )

                # These method implementations cause problems when trying to
                # override them with different types in a direct call class.
                #
                # For the `isdtor` case there's a simple solution: it doesn't
                # make much sense to specify arguments and then completely
                # ignore them, and the no-arg case isn't a problem for
                # overriding.
                if isctor or (isdtor and not md.inParams):
                    defaultRecv = MethodDefn(recvDecl)
                    defaultRecv.addcode("return IPC_OK();\n")
                    self.cls.addstmt(defaultRecv)
                elif self.protocol.implAttribute(self.side) == "virtual":
                    # If we're using virtual calls, we need the methods to be
                    # declared on the base class.
                    recvDecl.methodspec = MethodSpec.PURE
                    self.cls.addstmt(StmtDecl(recvDecl))

        # If we're using virtual calls, we need the methods to be declared on
        # the base class.
        if self.protocol.implAttribute(self.side) == "virtual":
            for md in p.messageDecls:
                managed = md.decl.type.constructedType()
                if not ptype.isManagerOf(managed) or md.decl.type.isDtor():
                    continue

                # add the Alloc interface for managed actors
                actortype = md.actorDecl().bareType(self.side)

                if managed.isRefcounted():
                    if not self.receivesMessage(md):
                        continue

                    actortype.ptr = False
                    actortype = _alreadyaddrefed(actortype)

                self.cls.addstmt(
                    StmtDecl(
                        MethodDecl(
                            _allocMethod(managed, self.side),
                            params=md.makeCxxParams(
                                side=self.side, implicit=False, direction="recv"
                            ),
                            ret=actortype,
                            methodspec=MethodSpec.PURE,
                        )
                    )
                )

            # add the Dealloc interface for all managed non-refcounted actors,
            # even without ctors. This is useful for protocols which use
            # ManagedEndpoint for construction.
            for managed in ptype.manages:
                if managed.isRefcounted():
                    continue

                self.cls.addstmt(
                    StmtDecl(
                        MethodDecl(
                            _deallocMethod(managed, self.side),
                            params=[
                                Decl(p.managedCxxType(managed, self.side), "aActor")
                            ],
                            ret=Type.BOOL,
                            methodspec=MethodSpec.PURE,
                        )
                    )
                )

        if ptype.isToplevel():
            # void ProcessingError(code); default to no-op
            processingerror = MethodDefn(
                MethodDecl(
                    p.processingErrorVar().name,
                    params=[
                        Param(_Result.Type(), "aCode"),
                        Param(Type("char", const=True, ptr=True), "aReason"),
                    ],
                    methodspec=MethodSpec.OVERRIDE,
                )
            )

            # bool ShouldContinueFromReplyTimeout(); default to |true|
            shouldcontinue = MethodDefn(
                MethodDecl(
                    p.shouldContinueFromTimeoutVar().name,
                    ret=Type.BOOL,
                    methodspec=MethodSpec.OVERRIDE,
                )
            )
            shouldcontinue.addcode("return true;\n")

            self.cls.addstmts(
                [
                    processingerror,
                    shouldcontinue,
                    Whitespace.NL,
                ]
            )

        self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL]))

        self.cls.addstmt(Label.PUBLIC)
        # Actor()
        ctor = ConstructorDefn(ConstructorDecl(self.clsname))
        side = ExprVar("mozilla::ipc::" + self.side.title() + "Side")
        if ptype.isToplevel():
            name = ExprLiteral.String(_actorName(p.name, self.side))
            ctor.memberinits = [
                ExprMemberInit(
                    ExprVar("mozilla::ipc::IToplevelProtocol"),
                    [name, _protocolId(ptype), side],
                )
            ]
        else:
            ctor.memberinits = [
                ExprMemberInit(
                    ExprVar("mozilla::ipc::IProtocol"), [_protocolId(ptype), side]
                )
            ]

        ctor.addcode("MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname)
        self.cls.addstmts([ctor, Whitespace.NL])

        # ~Actor()
        dtor = DestructorDefn(
            DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL)
        )
        dtor.addcode("MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname)

        self.cls.addstmts([dtor, Whitespace.NL])

        if ptype.isRefcounted():
            if not ptype.isToplevel():
                self.cls.addcode(
                    """
                    NS_INLINE_DECL_PURE_VIRTUAL_REFCOUNTING
                    """
                )
            self.cls.addstmt(Label.PROTECTED)
            self.cls.addcode(
                """
                void ActorAlloc() final { AddRef(); }
                void ActorDealloc() final { Release(); }
                """
            )

        self.cls.addstmt(Label.PUBLIC)
        if ptype.hasOtherPid():
            otherpidmeth = MethodDefn(
                MethodDecl("OtherPid", ret=Type("::base::ProcessId"), const=True)
            )
            otherpidmeth.addcode(
                """
                ::base::ProcessId pid =
                    ::mozilla::ipc::IProtocol::ToplevelProtocol()->OtherPidMaybeInvalid();
                MOZ_RELEASE_ASSERT(pid != ::base::kInvalidProcessId);
                return pid;
                """
            )
            self.cls.addstmts([otherpidmeth, Whitespace.NL])

        if not ptype.isToplevel():
            if 1 == len(p.managers):
                # manager() const
                managertype = p.managerActorType(self.side, ptr=True)
                managermeth = MethodDefn(
                    MethodDecl("Manager", ret=managertype, const=True)
                )
                managermeth.addcode(
                    """
                    return static_cast<${type}>(IProtocol::Manager());
                    """,
                    type=managertype,
                )

                self.cls.addstmts([managermeth, Whitespace.NL])

        def actorFromIter(itervar):
            return ExprCode("${iter}.Get()->GetKey()", iter=itervar)

        def forLoopOverHashtable(hashtable, itervar, const=False):
            itermeth = "ConstIter" if const else "Iter"
            return StmtFor(
                init=ExprCode(
                    "auto ${itervar} = ${hashtable}.${itermeth}()",
                    itervar=itervar,
                    hashtable=hashtable,
                    itermeth=itermeth,
                ),
                cond=ExprCode("!${itervar}.Done()", itervar=itervar),
                update=ExprCode("${itervar}.Next()", itervar=itervar),
            )

        # Managed[T](Array& inout) const
        # const Array<T>& Managed() const
        for managed in ptype.manages:
            container = p.managedVar(managed, self.side)

            meth = MethodDefn(
                MethodDecl(
                    p.managedMethod(managed, self.side).name,
                    params=[
                        Decl(
                            _cxxArrayType(
                                p.managedCxxType(managed, self.side), ref=True
                            ),
                            "aArr",
                        )
                    ],
                    const=True,
                )
            )
            meth.addcode("${container}.ToArray(aArr);\n", container=container)

            refmeth = MethodDefn(
                MethodDecl(
                    p.managedMethod(managed, self.side).name,
                    params=[],
                    ret=p.managedVarType(managed, self.side, const=True, ref=True),
                    const=True,
                )
            )
            refmeth.addcode("return ${container};\n", container=container)

            self.cls.addstmts([meth, refmeth, Whitespace.NL])

        # AllManagedActors(Array& inout) const
        arrvar = ExprVar("arr__")
        managedmeth = MethodDefn(
            MethodDecl(
                "AllManagedActors",
                params=[
                    Decl(
                        _cxxArrayType(_refptr(_cxxLifecycleProxyType()), ref=True),
                        arrvar.name,
                    )
                ],
                methodspec=MethodSpec.OVERRIDE,
                const=True,
            )
        )

        # Count the number of managed actors, and allocate space in the output array.
        managedmeth.addcode(
            """
            uint32_t total = 0;
            """
        )
        for managed in ptype.manages:
            managedmeth.addcode(
                """
                total += ${container}.Count();
                """,
                container=p.managedVar(managed, self.side),
            )
        managedmeth.addcode(
            """
            arr__.SetCapacity(total);

            """
        )

        for managed in ptype.manages:
            managedmeth.addcode(
                """
                for (auto* key : ${container}) {
                    arr__.AppendElement(key->GetLifecycleProxy());
                }

                """,
                container=p.managedVar(managed, self.side),
            )

        self.cls.addstmts([managedmeth, Whitespace.NL])

        # OpenPEndpoint(...)/BindPEndpoint(...)
        for managed in ptype.manages:
            self.genManagedEndpoint(managed)

        # OnMessageReceived()/OnCallReceived()

        # save these away for use in message handler case stmts
        msgvar = ExprVar("msg__")
        self.msgvar = msgvar
        replyvar = ExprVar("reply__")
        self.replyvar = replyvar
        var = ExprVar("v__")
        self.var = var
        # for ctor recv cases, we can't read the actor ID into a PFoo*
        # because it doesn't exist on this side yet.  Use a "special"
        # actor handle instead
        handlevar = ExprVar("handle__")
        self.handlevar = handlevar

        msgtype = ExprCode("msg__.type()")
        self.asyncSwitch = StmtSwitch(msgtype)
        self.syncSwitch = None
        self.interruptSwitch = None
        if toplevel.isSync() or toplevel.isInterrupt():
            self.syncSwitch = StmtSwitch(msgtype)
            if toplevel.isInterrupt():
                self.interruptSwitch = StmtSwitch(msgtype)

        # Add a handler for the MANAGED_ENDPOINT_BOUND and
        # MANAGED_ENDPOINT_DROPPED message types for managed actors.
        if not ptype.isToplevel():
            clearawaitingmanagedendpointbind = """
                if (!mAwaitingManagedEndpointBind) {
                    NS_WARNING("Unexpected managed endpoint lifecycle message after actor bound!");
                    return MsgNotAllowed;
                }
                mAwaitingManagedEndpointBind = false;
                """
            self.asyncSwitch.addcase(
                CaseLabel("MANAGED_ENDPOINT_BOUND_MESSAGE_TYPE"),
                StmtBlock(
                    [
                        StmtCode(clearawaitingmanagedendpointbind),
                        StmtReturn(_Result.Processed),
                    ]
                ),
            )
            self.asyncSwitch.addcase(
                CaseLabel("MANAGED_ENDPOINT_DROPPED_MESSAGE_TYPE"),
                StmtBlock(
                    [
                        StmtCode(clearawaitingmanagedendpointbind),
                        *self.destroyActor(
                            None,
                            ExprVar.THIS,
                            why=_DestroyReason.ManagedEndpointDropped,
                        ),
                        StmtReturn(_Result.Processed),
                    ]
                ),
            )

        # implement Send*() methods and add dispatcher cases to
        # message switch()es
        for md in p.messageDecls:
            self.visitMessageDecl(md)

        # add default cases
        default = StmtCode(
            """
            return MsgNotKnown;
            """
        )
        self.asyncSwitch.addcase(DefaultLabel(), default)
        if toplevel.isSync() or toplevel.isInterrupt():
            self.syncSwitch.addcase(DefaultLabel(), default)
            if toplevel.isInterrupt():
                self.interruptSwitch.addcase(DefaultLabel(), default)

        self.cls.addstmts(self.implementManagerIface())

        def makeHandlerMethod(name, switch, hasReply, dispatches=False):
            params = [Decl(Type("Message", const=True, ref=True), msgvar.name)]
            if hasReply:
                params.append(Decl(Type("UniquePtr<Message>", ref=True), replyvar.name))

            method = MethodDefn(
                MethodDecl(
                    name,
                    methodspec=MethodSpec.OVERRIDE,
                    params=params,
                    ret=_Result.Type(),
                )
            )

            if not switch:
                method.addcode(
                    """
                    MOZ_ASSERT_UNREACHABLE("message protocol not supported");
                    return MsgNotKnown;
                    """
                )
                return method

            if dispatches:
                if hasReply:
                    ondeadactor = [StmtReturn(_Result.RouteError)]
                else:
                    ondeadactor = [
                        self.logMessage(
                            None, ExprAddrOf(msgvar), "Ignored message for dead actor"
                        ),
                        StmtReturn(_Result.Processed),
                    ]

                method.addcode(
                    """
                    int32_t route__ = ${msgvar}.routing_id();
                    if (MSG_ROUTING_CONTROL != route__) {
                        IProtocol* routed__ = Lookup(route__);
                        if (!routed__ || !routed__->GetLifecycleProxy()) {
                            $*{ondeadactor}
                        }

                        RefPtr<mozilla::ipc::ActorLifecycleProxy> proxy__ =
                            routed__->GetLifecycleProxy();
                        return proxy__->Get()->${name}($,{args});
                    }

                    """,
                    msgvar=msgvar,
                    ondeadactor=ondeadactor,
                    name=name,
                    args=[p.name for p in params],
                )

            # bug 509581: don't generate the switch stmt if there
            # is only the default case; MSVC doesn't like that
            if switch.nr_cases > 1:
                method.addstmt(switch)
            else:
                method.addstmt(StmtReturn(_Result.NotKnown))

            return method

        dispatches = ptype.isToplevel() and ptype.isManager()
        self.cls.addstmts(
            [
                makeHandlerMethod(
                    "OnMessageReceived",
                    self.asyncSwitch,
                    hasReply=False,
                    dispatches=dispatches,
                ),
                Whitespace.NL,
            ]
        )
        self.cls.addstmts(
            [
                makeHandlerMethod(
                    "OnMessageReceived",
                    self.syncSwitch,
                    hasReply=True,
                    dispatches=dispatches,
                ),
                Whitespace.NL,
            ]
        )
        self.cls.addstmts(
            [
                makeHandlerMethod(
                    "OnCallReceived",
                    self.interruptSwitch,
                    hasReply=True,
                    dispatches=dispatches,
                ),
                Whitespace.NL,
            ]
        )

        clearsubtreevar = ExprVar("ClearSubtree")

        if ptype.isToplevel():
            # OnChannelClose()
            onclose = MethodDefn(
                MethodDecl("OnChannelClose", methodspec=MethodSpec.OVERRIDE)
            )
            onclose.addcode(
                """
                DestroySubtree(NormalShutdown);
                ClearSubtree();
                DeallocShmems();
                if (GetLifecycleProxy()) {
                    GetLifecycleProxy()->Release();
                }
                """
            )
            self.cls.addstmts([onclose, Whitespace.NL])

            # OnChannelError()
            onerror = MethodDefn(
                MethodDecl("OnChannelError", methodspec=MethodSpec.OVERRIDE)
            )
            onerror.addcode(
                """
                DestroySubtree(AbnormalShutdown);
                ClearSubtree();
                DeallocShmems();
                if (GetLifecycleProxy()) {
                    GetLifecycleProxy()->Release();
                }
                """
            )
            self.cls.addstmts([onerror, Whitespace.NL])

        if ptype.isToplevel() and ptype.isInterrupt():
            processnative = MethodDefn(
                MethodDecl("ProcessNativeEventsInInterruptCall", ret=Type.VOID)
            )
            processnative.addcode(
                """
                #ifdef XP_WIN
                GetIPCChannel()->ProcessNativeEventsInInterruptCall();
                #else
                FatalError("This method is Windows-only");
                #endif
                """
            )

            self.cls.addstmts([processnative, Whitespace.NL])

        # private methods
        self.cls.addstmt(Label.PRIVATE)

        # ClearSubtree()
        clearsubtree = MethodDefn(MethodDecl(clearsubtreevar.name))
        for managed in ptype.manages:
            clearsubtree.addcode(
                """
                for (auto* key : ${container}) {
                    key->ClearSubtree();
                }
                for (auto* key : ${container}) {
                    // Recursively releasing ${container} kids.
                    auto* proxy = key->GetLifecycleProxy();
                    NS_IF_RELEASE(proxy);
                }
                ${container}.Clear();

                """,
                container=p.managedVar(managed, self.side),
            )

        # don't release our own IPC reference: either the manager will do it,
        # or we're toplevel
        self.cls.addstmts([clearsubtree, Whitespace.NL])

        if not ptype.isToplevel():
            self.cls.addstmts(
                [
                    StmtDecl(
                        Decl(Type.BOOL, "mAwaitingManagedEndpointBind"),
                        init=ExprLiteral.FALSE,
                    ),
                    Whitespace.NL,
                ]
            )

        for managed in ptype.manages:
            self.cls.addstmts(
                [
                    StmtDecl(
                        Decl(
                            p.managedVarType(managed, self.side),
                            p.managedVar(managed, self.side).name,
                        )
                    )
                ]
            )

    def genManagedEndpoint(self, managed):
        """Add Open<P>Endpoint() and Bind<P>Endpoint() methods for |managed|.

        Both generated methods are appended to |self.cls|:

        * Open<P>Endpoint(aActor) binds |aActor| as a managee of this actor,
          marks it as awaiting the peer's bind, and returns a
          ManagedEndpoint for the *other* side's actor type. If binding
          fails, a default-constructed endpoint of that type is returned.
        * Bind<P>Endpoint(aEndpoint, aActor) forwards to
          ManagedEndpoint::Bind() for *this* side's actor type, inserting
          |aActor| into the managed container, and returns its bool result.
        """
        hereEp = "ManagedEndpoint<%s>" % _actorName(managed.name(), self.side)
        thereEp = "ManagedEndpoint<%s>" % _actorName(
            managed.name(), _otherSide(self.side)
        )

        actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor")

        # ManagedEndpoint<PThere> OpenPEndpoint(PHere* aActor)
        openmeth = MethodDefn(
            MethodDecl(
                "Open%sEndpoint" % managed.name(),
                params=[
                    Decl(self.protocol.managedCxxType(managed, self.side), actor.name)
                ],
                ret=Type(thereEp),
            )
        )
        openmeth.addcode(
            """
            $*{bind}
            // Mark our actor as awaiting the other side to be bound. This will
            // be cleared when a `MANAGED_ENDPOINT_{DROPPED,BOUND}` message is
            // received.
            aActor->mAwaitingManagedEndpointBind = true;
            return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor);
            """,
            # On bind failure, return a default-constructed (empty) endpoint.
            bind=self.bindManagedActor(actor, errfn=ExprCall(ExprVar(thereEp))),
            thereEp=thereEp,
        )

        # bool BindPEndpoint(ManagedEndpoint<PHere> aEndpoint, PHere* aActor)
        #
        # The endpoint parameter is declared by value (no reference); callers
        # presumably move it in, since ManagedEndpoint looks move-only -- TODO
        # confirm.
        bindmeth = MethodDefn(
            MethodDecl(
                "Bind%sEndpoint" % managed.name(),
                params=[
                    Decl(Type(hereEp), "aEndpoint"),
                    Decl(self.protocol.managedCxxType(managed, self.side), actor.name),
                ],
                ret=Type.BOOL,
            )
        )
        bindmeth.addcode(
            """
            return aEndpoint.Bind(mozilla::ipc::PrivateIPDLInterface(), aActor, this, ${container});
            """,
            container=self.protocol.managedVar(managed, self.side),
        )

        self.cls.addstmts([openmeth, bindmeth, Whitespace.NL])

    def implementManagerIface(self):
        """Build the RemoveManagee()/DeallocManagee() overrides for this actor.

        For toplevel protocols this also registers the SHMEM_CREATED and
        SHMEM_DESTROYED handlers in |self.asyncSwitch| as a side effect.

        Returns a list of C++ AST nodes to append to the actor class: the
        RemoveManagee() override, the DeallocManagee() override, and a
        trailing blank line. Both overrides switch on the managee's protocol
        id and call FatalError("unreached") for ids without a case (and
        unconditionally when this protocol manages nothing).
        """
        p = self.protocol
        protocolbase = Type("IProtocol", ptr=True)

        if p.decl.type.isToplevel():
            # FIXME: This used to be declared conditionally based on whether
            # shmem appeared somewhere in the protocol hierarchy, however that
            # caused issues due to Shmem instances hidden within custom C++
            # types.
            self.asyncSwitch.addcase(
                CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"),
                self.genShmemCreatedHandler(),
            )
            self.asyncSwitch.addcase(
                CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"),
                self.genShmemDestroyedHandler(),
            )

        # all protocols share the "same" RemoveManagee() implementation
        pvar = ExprVar("aProtocolId")
        listenervar = ExprVar("aListener")
        removemanagee = MethodDefn(
            MethodDecl(
                p.removeManageeMethod().name,
                params=[
                    Decl(_protocolIdType(), pvar.name),
                    Decl(protocolbase, listenervar.name),
                ],
                methodspec=MethodSpec.OVERRIDE,
            )
        )

        if not p.managesStmts:
            removemanagee.addcode(
                """
                FatalError("unreached");
                return;
                """
            )
        else:
            switchontype = StmtSwitch(pvar)
            for managee in p.managesStmts:
                manageeipdltype = managee.decl.type
                manageecxxtype = _cxxBareType(
                    ipdl.type.ActorType(manageeipdltype), self.side
                )
                # Drop the managee from its container (it must be present),
                # then release the reference held via its lifecycle proxy.
                case = ExprCode(
                    """
                    {
                        ${manageecxxtype} actor = static_cast<${manageecxxtype}>(aListener);

                        const bool removed = ${container}.EnsureRemoved(actor);
                        MOZ_RELEASE_ASSERT(removed, "actor not managed by this!");

                        auto* proxy = actor->GetLifecycleProxy();
                        NS_IF_RELEASE(proxy);
                        return;
                    }
                    """,
                    manageecxxtype=manageecxxtype,
                    container=p.managedVar(manageeipdltype, self.side),
                )
                switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
            switchontype.addcase(
                DefaultLabel(),
                ExprCode(
                    """
                FatalError("unreached");
                return;
                """
                ),
            )
            removemanagee.addstmt(switchontype)

        # The `DeallocManagee` method is called for managed actors to trigger
        # deallocation when ActorLifecycleProxy is freed.
        deallocmanagee = MethodDefn(
            MethodDecl(
                p.deallocManageeMethod().name,
                params=[
                    Decl(_protocolIdType(), pvar.name),
                    Decl(protocolbase, listenervar.name),
                ],
                methodspec=MethodSpec.OVERRIDE,
            )
        )

        if not p.managesStmts:
            deallocmanagee.addcode(
                """
                FatalError("unreached");
                return;
                """
            )
        else:
            switchontype = StmtSwitch(pvar)
            for managee in p.managesStmts:
                manageeipdltype = managee.decl.type
                # Reference counted actor types don't have corresponding
                # `Dealloc` methods, as they are deallocated by releasing the
                # IPDL-held reference. Their ids fall through to the default
                # (FatalError) case.
                if manageeipdltype.isRefcounted():
                    continue

                case = StmtCode(
                    """
                    ${concrete}->${dealloc}(static_cast<${type}>(aListener));
                    return;
                    """,
                    concrete=self.concreteThis(),
                    dealloc=_deallocMethod(manageeipdltype, self.side),
                    type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side),
                )
                switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
            switchontype.addcase(
                DefaultLabel(),
                StmtCode(
                    """
                FatalError("unreached");
                return;
                """
                ),
            )
            deallocmanagee.addstmt(switchontype)

        return [removemanagee, deallocmanagee, Whitespace.NL]

    def genShmemCreatedHandler(self):
        """Return the switch-case body for SHMEM_CREATED_MESSAGE_TYPE.

        Valid only for toplevel protocols. The emitted C++ passes the incoming
        message to ShmemCreated(); a false result yields MsgPayloadError,
        otherwise the handler answers MsgProcessed.
        """
        assert self.protocol.decl.type.isToplevel()

        template = """
            {
                if (!ShmemCreated(${msgvar})) {
                    return MsgPayloadError;
                }
                return MsgProcessed;
            }
            """
        return StmtCode(template, msgvar=self.msgvar)

    def genShmemDestroyedHandler(self):
        """Return the switch-case body for SHMEM_DESTROYED_MESSAGE_TYPE.

        Toplevel protocols only. The generated C++ hands the message to
        ShmemDestroyed() and replies MsgPayloadError when that fails,
        MsgProcessed otherwise.
        """
        assert self.protocol.decl.type.isToplevel()

        handler = StmtCode(
            """
            {
                if (!ShmemDestroyed(${msgvar})) {
                    return MsgPayloadError;
                }
                return MsgProcessed;
            }
            """,
            msgvar=self.msgvar,
        )
        return handler

    # -------------------------------------------------------------------------
    # The next few functions are the crux of the IPDL code generator.
    # They generate code for all the nasty work of message
    # serialization/deserialization and dispatching handlers for
    # received messages.
    ##

    def concreteThis(self):
        """Return an expression for |this| viewed as the concrete actor class.

        When the side's impl attribute is "virtual", handlers are reached via
        virtual dispatch on the base class, so plain |this| is returned.
        Otherwise |this| is static_cast to the implementation class: the name
        carried by the attribute's string literal, or, with no attribute, the
        protocol name minus its leading "P" plus the capitalized side.
        """
        implAttr = self.protocol.implAttribute(self.side)
        if implAttr == "virtual":
            return ExprVar.THIS

        if implAttr is not None:
            assert isinstance(implAttr, ipdl.ast.StringLiteral)
            concreteName = implAttr.value
        else:
            protoName = self.protocol.name
            assert protoName.startswith("P")
            concreteName = protoName[1:] + self.side.capitalize()

        return ExprCode("static_cast<${className}*>(this)", className=concreteName)

    def thisCall(self, function, args):
        """Build the call expression `concreteThis()->function(args...)`."""
        callee = ExprSelect(self.concreteThis(), "->", function)
        return ExprCall(callee, args=args)

    def visitMessageDecl(self, md):
        """Generate the send methods and receive dispatch case for |md|.

        If this side sends |md|, the generated Send*() method(s) (and, for
        constructors of non-refcounted actors, a helper ctor) are appended to
        |self.cls|. Any dispatch case produced by the send generators, and
        the receive case when this side receives |md|, is registered in the
        async, sync or interrupt switch that matches the message's kind.
        Destructors of toplevel actors get neither a Send method nor a
        receive-side case from this function.
        """
        decltype = md.decl.type
        isctor = decltype.isCtor()
        isdtor = decltype.isDtor()

        sendmethod = movesendmethod = promisesendmethod = None
        recvlbl = recvcase = None

        def addRecvCase(lbl, case):
            # Route the case into the switch for this message's kind.
            if decltype.isAsync():
                targetSwitch = self.asyncSwitch
            elif decltype.isSync():
                targetSwitch = self.syncSwitch
            elif decltype.isInterrupt():
                targetSwitch = self.interruptSwitch
            else:
                assert 0
                return
            targetSwitch.addcase(lbl, case)

        if self.sendsMessage(md):
            isasync = decltype.isAsync()

            # NOTE: Don't generate helper ctors for refcounted types.
            #
            # Safety concerns around providing your own actor to a ctor (namely
            # that the return value won't be checked, and the argument will be
            # `delete`-ed) are less critical with refcounted actors, due to the
            # actor being held alive by the callsite.
            #
            # This allows refcounted actors to not implement crashing AllocPFoo
            # methods on the sending side.
            if isctor and not decltype.constructedType().isRefcounted():
                self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL])

            if isctor:
                if isasync:
                    sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md)
                else:
                    sendmethod = self.genBlockingCtorMethod(md)
            elif isdtor:
                if isasync:
                    sendmethod, (recvlbl, recvcase) = self.genAsyncDtor(md)
                else:
                    sendmethod = self.genBlockingDtorMethod(md)
            elif isasync:
                (
                    sendmethod,
                    movesendmethod,
                    promisesendmethod,
                    (recvlbl, recvcase),
                ) = self.genAsyncSendMethod(md)
            else:
                sendmethod, movesendmethod = self.genBlockingSendMethod(md)

        # XXX figure out what to do here
        toplevelDtor = isdtor and decltype.constructedType().isToplevel()
        if toplevelDtor:
            sendmethod = None

        for generated in (sendmethod, movesendmethod, promisesendmethod):
            if generated is not None:
                self.cls.addstmts([generated, Whitespace.NL])
        if recvcase is not None:
            addRecvCase(recvlbl, recvcase)

        if not self.receivesMessage(md):
            return

        if isctor:
            recvlbl, recvcase = self.genCtorRecvCase(md)
        elif isdtor:
            recvlbl, recvcase = self.genDtorRecvCase(md)
        else:
            recvlbl, recvcase = self.genRecvCase(md)

        # XXX figure out what to do here
        if toplevelDtor:
            return

        addRecvCase(recvlbl, recvcase)

    def genAsyncCtor(self, md):
        """Generate the asynchronous constructor send method for |md|.

        Returns ``(method, (label, case))``. The generated method binds the
        new actor to this manager, builds and sends the constructor message,
        and destroys the actor (returning null) if the send reports failure.
        The ``(label, case)`` pair handles the constructor's reply id and
        currently only returns ``_Result.Processed``.
        """
        newactor = md.actorDecl()
        sendfn = MethodDefn(self.makeSendMethodDecl(md))

        ctormsg, buildstmts = self.makeMessage(md, errfnSendCtor)
        sentok, sendstmts = self.sendAsync(md, ctormsg)

        sendfn.addcode(
            """
            $*{bind}

            // Build our constructor message.
            $*{stmts}

            // Notify the other side about the newly created actor. This can
            // fail if our manager has already been destroyed.
            //
            // NOTE: If the send call fails due to toplevel channel teardown,
            // the `IProtocol::ChannelSend` wrapper absorbs the error for us,
            // so we don't tear down actors unexpectedly.
            $*{sendstmts}

            // Warn, destroy the actor, and return null if the message failed to
            // send. Otherwise, return the successfully created actor reference.
            if (!${sendok}) {
                NS_WARNING("Error sending ${actorname} constructor");
                $*{destroy}
                return nullptr;
            }
            return ${actor};
            """,
            bind=self.bindManagedActor(newactor),
            stmts=buildstmts,
            sendstmts=sendstmts,
            sendok=sentok,
            destroy=self.destroyActor(
                md, newactor.var(), why=_DestroyReason.FailedConstructor
            ),
            actor=newactor.var(),
            actorname=newactor.ipdltype.protocol.name() + self.side.capitalize(),
        )

        # TODO not really sure what to do with async ctor "replies" yet.
        # destroy actor if there was an error?  tricky ...
        replycase = StmtBlock()
        replycase.addstmt(StmtReturn(_Result.Processed))
        return sendfn, (CaseLabel(md.pqReplyId()), replycase)

    def genBlockingCtorMethod(self, md):
        """Generate the synchronous constructor send method for |md|.

        The generated method binds the new actor to this manager, builds the
        constructor message, and sends it synchronously. If the send fails it
        warns (naming the actor, consistent with genAsyncCtor), destroys the
        actor and returns null. Otherwise it deserializes the reply and
        returns the actor.
        """
        actor = md.actorDecl()
        method = MethodDefn(self.makeSendMethodDecl(md))

        msgvar, stmts = self.makeMessage(md, errfnSendCtor)

        replyvar = self.replyvar
        sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
        # Reply deserialization failures make the generated method return null.
        replystmts = self.deserializeReply(
            md,
            replyvar,
            self.side,
            errfnSendCtor,
            errfnSentinel(ExprLiteral.NULL),
        )

        method.addcode(
            """
            $*{bind}

            // Build our constructor message.
            $*{stmts}

            // Synchronously send the constructor message to the other side. If
            // the send fails, e.g. due to the remote side shutting down, the
            // actor will be destroyed and potentially freed.
            UniquePtr<Message> ${replyvar};
            $*{sendstmts}

            if (!(${sendok})) {
                // Warn, destroy the actor and return null if the message
                // failed to send.
                NS_WARNING("Error sending ${actorname} constructor");
                $*{destroy}
                return nullptr;
            }

            $*{replystmts}
            return ${actor};
            """,
            bind=self.bindManagedActor(actor),
            stmts=stmts,
            replyvar=replyvar,
            sendstmts=sendstmts,
            sendok=sendok,
            destroy=self.destroyActor(
                md, actor.var(), why=_DestroyReason.FailedConstructor
            ),
            replystmts=replystmts,
            actor=actor.var(),
            # Used in the NS_WARNING above so failures identify the actor.
            actorname=actor.ipdltype.protocol.name() + self.side.capitalize(),
        )

        return method

    def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None):
        """Return statements that null-check |actordecl|'s actor and register
        it as a managee of ``this``.

        |errfn| is the expression the generated code returns when the actor is
        null. When |idexpr| is given, it is passed to
        ``SetManagerAndRegister`` after ``this``. The actor is also inserted
        into this manager's container for its protocol.
        """
        managedproto = actordecl.ipdltype.protocol
        registerArgs = [ExprVar.THIS] + ([] if idexpr is None else [idexpr])

        bindstmt = StmtCode(
            """
            if (!${actor}) {
                NS_WARNING("Cannot bind null ${actorname} actor");
                return ${errfn};
            }

            ${actor}->SetManagerAndRegister($,{setManagerArgs});
            ${container}.Insert(${actor});
            """,
            actor=actordecl.var(),
            actorname=managedproto.name() + self.side.capitalize(),
            errfn=errfn,
            setManagerArgs=registerArgs,
            container=self.protocol.managedVar(managedproto, self.side),
        )
        return [bindstmt]

    def genHelperCtor(self, md):
        """Generate the constructor overload that omits the actor argument.

        The generated helper allocates the actor via the Alloc method and
        then forwards, with its arguments moved, to the full send method of
        the same name.
        """
        decl = self.makeSendMethodDecl(md)
        # Drop the leading actor parameter; the helper allocates it itself.
        decl.params = decl.params[1:]

        allocstmt = self.callAllocActor(md, retsems="out", side=self.side)
        forward = StmtReturn(
            ExprCall(ExprVar(decl.name), args=md.makeCxxArgs(paramsems="move"))
        )

        helper = MethodDefn(decl)
        helper.addstmts([allocstmt, forward])
        return helper

    def genAsyncDtor(self, md):
        """Generate the static asynchronous ``Send__delete__`` method for |md|.

        The generated method checks the actor via dtorPrologue, sends the
        destructor message from that actor, tears the actor down via
        dtorEpilogue, and returns the send result. Returns
        ``(method, (label, case))``, where the case for the reply id only
        returns ``_Result.Processed``.
        """
        actorvar = ExprVar("actor")
        method = MethodDefn(self.makeDtorMethodDecl(md, actorvar))
        method.addstmt(self.dtorPrologue(actorvar))

        msgvar, buildstmts = self.makeMessage(md, errfnSendDtor, actorvar)
        sendok, sendstmts = self.sendAsync(md, msgvar, actorvar)

        body = list(buildstmts)
        body.extend(sendstmts)
        body.append(Whitespace.NL)
        body.extend(self.dtorEpilogue(md, actorvar))
        body.append(StmtReturn(sendok))
        method.addstmts(body)

        label = CaseLabel(md.pqReplyId())
        replycase = StmtBlock()
        replycase.addstmt(StmtReturn(_Result.Processed))
        # TODO if the dtor is "inherently racy", keep the actor alive
        # until the other side acks

        return method, (label, replycase)

    def genBlockingDtorMethod(self, md):
        """Generate the static, synchronous ``Send__delete__`` method for |md|.

        The generated method takes the actor to delete as its first
        parameter. It returns false early (via dtorPrologue) if that actor is
        null or can no longer send. It then sends the destructor message
        synchronously, tears the actor down (via dtorEpilogue), and returns
        whether the send succeeded.
        """
        actorvar = ExprVar("actor")
        # makeDtorMethodDecl prepends the actor parameter and marks the
        # method static.
        method = MethodDefn(self.makeDtorMethodDecl(md, actorvar))

        method.addstmt(self.dtorPrologue(actorvar))

        msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar)

        replyvar = self.replyvar
        sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar, actorvar)
        method.addstmts(
            stmts
            + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))]
            + sendstmts
        )

        destmts = self.deserializeReply(
            md, replyvar, self.side, errfnSend, errfnSentinel(), actorvar
        )
        # NOTE(review): the reply-deserialization statements are wrapped in a
        # literal `if (false)`, so they are emitted but never executed, and
        # the `sendok &= false` inside never runs. Presumably this keeps the
        # reply-handling code in the output without acting on it -- confirm
        # intent.
        ifsendok = StmtIf(ExprLiteral.FALSE)
        ifsendok.addifstmts(destmts)
        ifsendok.addifstmts(
            [Whitespace.NL, StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, "&="))]
        )

        method.addstmt(ifsendok)

        # The actor is torn down whether or not the send succeeded; the send
        # result is only reported through the return value.
        method.addstmts(
            self.dtorEpilogue(md, actorvar) + [Whitespace.NL, StmtReturn(sendok)]
        )

        return method

    def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion):
        """Return statements that destroy |actorexpr| and unregister it.

        The generated code calls ``DestroySubtree(why)`` and
        ``ClearSubtree()`` on the actor, then removes it from its manager.
        The protocol id given to ``RemoveManagee`` is that of the constructed
        protocol when |md| is a constructor, and this protocol's otherwise
        (including when |md| is None).
        """
        if not (md and md.decl.type.isCtor()):
            destroyedType = self.protocol.decl.type
        else:
            destroyedType = md.decl.type.constructedType()

        teardown = StmtCode(
            """
            IProtocol* mgr = ${actor}->Manager();
            ${actor}->DestroySubtree(${why});
            ${actor}->ClearSubtree();
            mgr->RemoveManagee(${protoId}, ${actor});
            """,
            actor=actorexpr,
            why=why,
            protoId=_protocolId(destroyedType),
        )
        return [teardown]

    def dtorPrologue(self, actorexpr):
        """Return a guard statement for ``Send__delete__``.

        The generated code warns and returns false when |actorexpr| is null
        or its ``CanSend()`` is false.
        """
        guard = StmtCode(
            """
            if (!${actor} || !${actor}->CanSend()) {
                NS_WARNING("Attempt to __delete__ missing or closed actor");
                return false;
            }
            """,
            actor=actorexpr,
        )
        return guard

    def dtorEpilogue(self, md, actorexpr):
        """Return the teardown statements for a ``__delete__`` message.

        Delegates to destroyActor with its default destroy reason.
        """
        teardown = self.destroyActor(md=md, actorexpr=actorexpr)
        return teardown

    def genRecvAsyncReplyCase(self, md):
        """Generate the switch case that handles the reply to async |md|.

        The generated case pops the callback registered for the message. If
        no callback is found, it reports a fatal error. Otherwise it reads the
        ``resolve__`` flag and then does one of two things:

        - deserializes the return values and calls ``callback->Resolve``; or
        - reads the reject reason and calls ``callback->Reject``.

        Returns ``(label, case)``.
        """
        label = CaseLabel(md.pqReplyId())
        block = StmtBlock()
        resolve, reason, prologue, desrej, desstmts = self.deserializeAsyncReply(
            md, self.side, errfnRecv, errfnSentinel(_Result.ValuError)
        )

        returns = md.returns
        if len(returns) <= 1:
            # A single return value is passed to Resolve directly.
            resolvetype = returns[0].bareType(self.side)
            resolvearg = ExprMove(returns[0].var())
        else:
            # Several return values are packed into a std::tuple.
            resolvetype = _tuple([r.bareType(self.side) for r in returns])
            resolvearg = ExprCall(
                ExprVar("std::make_tuple"), args=[ExprMove(r.var()) for r in returns]
            )

        block.addcode(
            """
            $*{prologue}

            UniquePtr<MessageChannel::UntypedCallbackHolder> untypedCallback =
                GetIPCChannel()->PopCallback(${msgvar}, Id());

            typedef MessageChannel::CallbackHolder<${resolvetype}> CallbackHolder;
            auto* callback = static_cast<CallbackHolder*>(untypedCallback.get());
            if (!callback) {
                FatalError("Error unknown callback");
                return MsgProcessingError;
            }

            if (${resolve}) {
                $*{desstmts}
                callback->Resolve(${resolvearg});
            } else {
                $*{desrej}
                callback->Reject(std::move(${reason}));
            }
            return MsgProcessed;
            """,
            prologue=prologue,
            msgvar=self.msgvar,
            resolve=resolve,
            resolvetype=resolvetype,
            desstmts=desstmts,
            resolvearg=resolvearg,
            desrej=desrej,
            reason=reason,
        )

        return (label, block)

    def genAsyncSendMethod(self, md):
        """Generate the send method(s) for async, non-ctor/dtor |md|.

        Returns ``(method, movemethod, promisemethod, (label, case))``:

        - ``movemethod`` is always None.
        - When |md| has returns, ``promisemethod`` is the promise-returning
          overload and ``(label, case)`` handles the reply.
        - Otherwise, ``promisemethod``, ``label`` and ``case`` are all None.

        Both methods are virtual when |md| has the ``VirtualSendImpl``
        attribute.
        """
        makevirtual = "VirtualSendImpl" in md.attributes

        def senddecl(**kwargs):
            # Shared decl construction for the callback and promise overloads.
            decl = self.makeSendMethodDecl(md, **kwargs)
            if makevirtual:
                decl.methodspec = MethodSpec.VIRTUAL
            return decl

        method = MethodDefn(senddecl())
        msgvar, buildstmts = self.makeMessage(md, errfnSend)
        retvar, sendstmts = self.sendAsync(md, msgvar)
        method.addstmts(
            buildstmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)]
        )

        if not md.returns:
            return method, None, None, (None, None)

        # Add the promise overload, plus the case that receives the reply.
        promisemethod = MethodDefn(senddecl(promise=True))
        promisemethod.addstmts(self.sendAsyncWithPromise(md))
        label, replycase = self.genRecvAsyncReplyCase(md)
        return method, None, promisemethod, (label, replycase)

    def genBlockingSendMethod(self, md):
        """Generate the synchronous send method for non-ctor/dtor |md|.

        The generated method serializes and sends the message, returns false
        if the send fails, and otherwise deserializes the reply into the out
        parameters and returns true. Returns ``(method, movemethod)``; the
        second element is always None.
        """
        method = MethodDefn(self.makeSendMethodDecl(md))

        msgvar, serstmts = self.makeMessage(md, errfnSend)
        reply = self.replyvar

        sendok, sendstmts = self.sendBlocking(md, msgvar, reply)
        bailout = StmtIf(ExprNot(sendok))
        bailout.addifstmt(StmtReturn.FALSE)

        readreply = self.deserializeReply(
            md, reply, self.side, errfnSend, errfnSentinel()
        )
        replydecl = StmtDecl(Decl(Type("UniquePtr<Message>"), reply.name))

        body = list(serstmts)
        body += [Whitespace.NL, replydecl]
        body += sendstmts
        body.append(bailout)
        body += readreply
        body += [Whitespace.NL, StmtReturn.TRUE]
        method.addstmts(body)

        return method, None

    def genCtorRecvCase(self, md):
        """Generate the switch case that receives constructor message |md|.

        The generated case proceeds in order:

        1. Deserializes the message.
        2. Declares locals for any return values.
        3. Allocates the actor via the Alloc method and binds it under the
           id of the received actor handle.
        4. Invokes the Recv handler.
        5. Builds the reply, if there is one.

        Returns ``(label, case)``.
        """
        label = CaseLabel(md.pqMsgId())
        block = StmtBlock()
        handle = self.handlevar

        body = list(
            self.deserializeMessage(
                md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
            )
        )
        idvar, saveIdStmts = self.saveActorId(md)

        # Locals the Recv handler fills in as out-values.
        body.extend(
            StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
            for r in md.returns
        )
        # alloc the actor, register it under the foreign ID
        body.append(self.callAllocActor(md, retsems="in", side=self.side))
        body.extend(
            self.bindManagedActor(
                md.actorDecl(), errfn=_Result.ValuError, idexpr=_actorHId(handle)
            )
        )
        body.append(Whitespace.NL)
        body.extend(saveIdStmts)
        body.extend(self.invokeRecvHandler(md))
        body.extend(self.makeReply(md, errfnRecv, idvar))
        body.extend([Whitespace.NL, StmtReturn(_Result.Processed)])

        block.addstmts(body)
        return label, block

    def genDtorRecvCase(self, md):
        """Generate the switch case that receives destructor message |md|.

        The generated case deserializes the message, invokes the Recv
        handler, builds the reply (if any) using the saved routing id, and
        then destroys ``this`` via dtorEpilogue. Returns ``(label, case)``.
        """
        label = CaseLabel(md.pqMsgId())
        block = StmtBlock()

        body = list(
            self.deserializeMessage(
                md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
            )
        )
        idvar, saveIdStmts = self.saveActorId(md)

        # Locals the Recv handler fills in as out-values.
        body.extend(
            StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
            for r in md.returns
        )
        body.extend(self.invokeRecvHandler(md))
        body.append(Whitespace.NL)
        # The id is captured before the reply is built and the actor is
        # destroyed below.
        body.extend(saveIdStmts)
        body.extend(self.makeReply(md, errfnRecv, routingId=idvar))
        body.append(Whitespace.NL)
        body.extend(self.dtorEpilogue(md, ExprVar.THIS))
        body.extend([Whitespace.NL, StmtReturn(_Result.Processed)])

        block.addstmts(body)
        return label, block

    def genRecvCase(self, md):
        """Generate the switch case that receives ordinary message |md|.

        Async messages with return values get a resolver (see makeResolver).
        Other messages get plain locals for their return values. The Recv
        handler is then invoked and, for sync replies, the reply is built.
        Returns ``(label, case)``.
        """
        label = CaseLabel(md.pqMsgId())
        block = StmtBlock()

        readstmts = self.deserializeMessage(
            md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
        )
        idvar, saveIdStmts = self.saveActorId(md)

        if md.decl.type.isAsync() and md.returns:
            retdecls = self.makeResolver(md, errfnRecv, routingId=idvar)
        else:
            retdecls = [
                StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
                for r in md.returns
            ]

        body = readstmts + saveIdStmts + retdecls
        body += self.invokeRecvHandler(md)
        body += [Whitespace.NL]
        body += self.makeReply(md, errfnRecv, routingId=idvar)
        body += [StmtReturn(_Result.Processed)]
        block.addstmts(body)

        return label, block

    # helper methods

    def makeMessage(self, md, errfn, fromActor=None):
        """Build the statements that construct and serialize |md|'s message.

        The message is created with the routing id for |fromActor| (or for
        this actor when None). Each parameter is written through an
        ``IPC::MessageWriter`` bound to |fromActor| or ``this``. |errfn| is
        not used here. Returns ``(msgvar, stmts)``.
        """
        msgvar = self.msgvar
        writer = ExprVar("writer__")
        routingId = self.protocol.routingId(fromActor)
        owner = fromActor or ExprVar.THIS

        setup = [
            StmtDecl(
                Decl(Type("UniquePtr<IPC::Message>"), msgvar.name),
                init=ExprCall(ExprVar(md.pqMsgCtorFunc()), args=[routingId]),
            ),
            StmtDecl(
                Decl(Type("IPC::MessageWriter"), writer.name),
                initargs=[ExprDeref(msgvar), owner],
            ),
            Whitespace.NL,
        ]
        writes = [
            _ParamTraits.checkedWrite(
                p.ipdltype,
                p.var(),
                ExprAddrOf(writer),
                sentinelKey=p.name,
            )
            for p in md.params
        ]
        writes.append(Whitespace.NL)

        return msgvar, setup + writes + self.setMessageFlags(md, msgvar)

    def makeResolver(self, md, errfn, routingId):
        """Build the resolver object for async |md| with a reply.

        The generated code creates the reply message, copies the request's
        seqno onto it, and wraps it in an ``IPDLResolverInner``. It then
        declares ``resolver``, a lambda that serializes its argument into the
        reply and logs it. With several return values the argument is a tuple
        read with ``std::get``.

        Returns [] when |md| is not async or has no reply. |routingId|
        defaults to this actor's routing id. |errfn| is not used.
        """
        if routingId is None:
            routingId = self.protocol.routingId()
        if not (md.decl.type.isAsync() and md.hasReply()):
            return []

        returncount = len(md.returns)

        def valueAt(idx):
            # Expression for the idx-th return value inside the lambda.
            assert idx < returncount
            if returncount <= 1:
                return ExprVar("aParam")
            return ExprCode("std::get<${idx}>(aParam)", idx=idx)

        writes = []
        for idx, p in enumerate(md.returns):
            writes.append(
                _ParamTraits.checkedWrite(
                    p.ipdltype,
                    valueAt(idx),
                    ExprAddrOf(ExprVar("writer__")),
                    sentinelKey=p.name,
                )
            )

        resolverstmt = StmtCode(
            """
            UniquePtr<IPC::Message> ${replyvar}(${replyCtor}(${routingId}));
            ${replyvar}->set_seqno(${msgvar}.seqno());

            RefPtr<mozilla::ipc::IPDLResolverInner> resolver__ =
                new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this);

            ${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam) {
                resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) {
                    IPC::MessageWriter writer__(*${replyvar}, self__);
                    $*{serializeParams}
                    ${logSendingReply}
                });
            };
            """,
            msgvar=self.msgvar,
            resolvertype=Type(md.resolverName()),
            routingId=routingId,
            resolveType=_resolveType(md.returns, self.side),
            replyvar=self.replyvar,
            replyCtor=ExprVar(md.pqReplyCtorFunc()),
            serializeParams=writes,
            logSendingReply=self.logMessage(
                md,
                self.replyvar,
                "Sending reply ",
                actor=ExprVar("self__"),
            ),
        )
        return [resolverstmt]

    def makeReply(self, md, errfn, routingId):
        """Build the statements that create and fill the reply to sync |md|.

        Returns [] when |md| has no reply, and also when |md| is async
        (makeResolver builds async replies). Otherwise the statements:

        - assign a freshly constructed reply message to ``self.replyvar``;
        - write each return value into it;
        - log it.

        |routingId| defaults to this actor's routing id. |errfn| is not used.
        """
        if routingId is None:
            routingId = self.protocol.routingId()
        # TODO special cases for async ctor/dtor replies
        msgtype = md.decl.type
        if not msgtype.hasReply() or msgtype.isAsync():
            return []

        replyvar = self.replyvar
        writer = ExprVar("writer__")
        setup = [
            StmtExpr(
                ExprAssn(
                    replyvar,
                    ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]),
                )
            ),
            StmtDecl(
                Decl(Type("IPC::MessageWriter"), writer.name),
                initargs=[ExprDeref(replyvar), ExprVar.THIS],
            ),
            Whitespace.NL,
        ]
        writes = [
            _ParamTraits.checkedWrite(
                r.ipdltype,
                r.var(),
                ExprAddrOf(writer),
                sentinelKey=r.name,
            )
            for r in md.returns
        ]
        trailer = self.setMessageFlags(md, replyvar)
        trailer.append(self.logMessage(md, replyvar, "Sending reply "))
        return setup + writes + trailer

    def setMessageFlags(self, md, var, seqno=None):
        """Return statements that set flags on message pointer |var|.

        When |seqno| is truthy, a ``set_seqno(seqno)`` call is emitted. The
        result always ends with a newline. |md| is not used.
        """
        flagstmts = (
            [StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno]))]
            if seqno
            else []
        )
        return flagstmts + [Whitespace.NL]

    def deserializeMessage(self, md, side, errfn, errfnSent):
        """Build the statements that log and deserialize incoming |md|.

        With no parameters, only the logging and profiler label are
        returned. Otherwise an ``IPC::MessageReader`` reads every parameter.
        For constructors, the first parameter is read as a raw
        ``ActorHandle`` into ``self.handlevar``. Parameters of ``tainted``
        messages are read as ``Tainted<T>`` unless marked ``NoTaint``.
        |errfn| and |errfnSent| are forwarded to each checked read.
        """
        msgvar = self.msgvar
        reader = ExprVar("reader__")
        stmts = [
            self.logMessage(md, ExprAddrOf(msgvar), "Received ", receiving=True),
            self.profilerLabel(md),
            Whitespace.NL,
        ]

        if not md.params:
            return stmts

        def readtype(p):
            bare = p.bareType(side)
            if md.decl.type.tainted and "NoTaint" not in p.attributes:
                return Type("Tainted", T=bare)
            return bare

        reads = []
        params = md.params
        if md.decl.type.isCtor():
            # return the raw actor handle so that its ID can be used
            # to construct the "real" actor
            handletype = Type("ActorHandle")
            reads.append(
                _ParamTraits.checkedRead(
                    None,
                    handletype,
                    self.handlevar,
                    ExprAddrOf(reader),
                    errfn,
                    "'%s'" % handletype.name,
                    sentinelKey="actor",
                    errfnSentinel=errfnSent,
                )
            )
            params = params[1:]

        for p in params:
            reads.append(
                _ParamTraits.checkedRead(
                    p.ipdltype,
                    readtype(p),
                    p.var(),
                    ExprAddrOf(reader),
                    errfn,
                    "'%s'" % p.ipdltype.name(),
                    sentinelKey=p.name,
                    errfnSentinel=errfnSent,
                )
            )

        stmts.append(
            StmtDecl(
                Decl(Type("IPC::MessageReader"), reader.name),
                initargs=[msgvar, ExprVar.THIS],
            )
        )
        stmts.append(Whitespace.NL)
        stmts.extend(reads)
        stmts.append(StmtCode("${reader}.EndRead();\n", reader=reader))
        return stmts

    def deserializeAsyncReply(self, md, side, errfn, errfnSent):
        """Build the statements that deserialize the reply to async |md|.

        Returns ``(resolve, reason, prologue, desrej, stmts)``:
          resolve  -- the ``resolve__`` variable (bool read from the reply).
          reason   -- the ``reason__`` variable (ResponseRejectReason).
          prologue -- logging, profiler label, the ``reader__`` declaration
                      and the read of ``resolve__``.
          desrej   -- reads ``reason__`` and ends reading (reject path).
          stmts    -- reads each return value (after the actor handle, for
                      constructors) and ends reading (resolve path).

        NOTE(review): when |md| has no returns this returns only the
        prologue list, not the 5-tuple above. In this file it is only reached
        (via genRecvAsyncReplyCase) for messages with returns.
        """
        msgvar = self.msgvar
        readervar = ExprVar("reader__")
        msgexpr = ExprAddrOf(msgvar)
        isctor = md.decl.type.isCtor()
        resolve = ExprVar("resolve__")
        reason = ExprVar("reason__")

        # NOTE: The `resolve__` and `reason__` parameters don't have sentinels,
        # as they are serialized by the IPDLResolverInner type in
        # ProtocolUtils.cpp rather than by generated code.
        desresolve = [
            StmtCode(
                """
                bool resolve__ = false;
                if (!IPC::ReadParam(&${readervar}, &resolve__)) {
                    FatalError("Error deserializing bool");
                    return MsgValueError;
                }
                """,
                readervar=readervar,
            ),
        ]
        desrej = [
            StmtCode(
                """
                ResponseRejectReason reason__{};
                if (!IPC::ReadParam(&${readervar}, &reason__)) {
                    FatalError("Error deserializing ResponseRejectReason");
                    return MsgValueError;
                }
                ${readervar}.EndRead();
                """,
                readervar=readervar,
            ),
        ]
        prologue = [
            self.logMessage(md, msgexpr, "Received ", receiving=True),
            self.profilerLabel(md),
            Whitespace.NL,
        ]

        # See NOTE(review) in the docstring: this early return is a list,
        # not the usual 5-tuple.
        if not md.returns:
            return prologue

        # The reader and the resolve flag are shared by both the resolve and
        # reject paths, so they go in the prologue.
        prologue.extend(
            [
                StmtDecl(
                    Decl(Type("IPC::MessageReader"), readervar.name),
                    initargs=[msgvar, ExprVar.THIS],
                )
            ]
            + desresolve
        )

        start, reads = 0, []
        if isctor:
            # return the raw actor handle so that its ID can be used
            # to construct the "real" actor
            handlevar = self.handlevar
            handletype = Type("ActorHandle")
            reads = [
                _ParamTraits.checkedRead(
                    None,
                    handletype,
                    handlevar,
                    ExprAddrOf(readervar),
                    errfn,
                    "'%s'" % handletype.name,
                    sentinelKey="actor",
                    errfnSentinel=errfnSent,
                )
            ]
            start = 1

        stmts = (
            reads
            + [
                _ParamTraits.checkedRead(
                    p.ipdltype,
                    p.bareType(side),
                    p.var(),
                    ExprAddrOf(readervar),
                    errfn,
                    "'%s'" % p.ipdltype.name(),
                    sentinelKey=p.name,
                    errfnSentinel=errfnSent,
                )
                for p in md.returns[start:]
            ]
            + [StmtCode("${reader}.EndRead();", reader=readervar)]
        )

        return resolve, reason, prologue, desrej, stmts

    def deserializeReply(self, md, replyexpr, side, errfn, errfnSentinel, actor=None):
        """Build the statements that log and deserialize a sync reply.

        Each return value of |md| is read from |replyexpr| into a temporary
        ``<name>__reply`` and then move-assigned through the corresponding
        out-parameter pointer. |errfn| and |errfnSentinel| are forwarded to
        each checked read.

        The ``IPC::MessageReader`` is built from |replyexpr| (previously
        ``self.replyvar`` was used regardless of the argument). It is bound
        to |actor| when one is given, falling back to ``this``; this matches
        makeMessage's writer and logMessage. Callers that pass an actor (the
        static ``Send__delete__`` method) thus no longer emit ``this``.
        """
        stmts = [
            Whitespace.NL,
            self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True),
        ]
        if not md.returns:
            return stmts

        def tempvar(r):
            return ExprVar(r.var().name + "__reply")

        readervar = ExprVar("reader__")
        stmts.extend(
            [
                Whitespace.NL,
                StmtDecl(
                    Decl(Type("IPC::MessageReader"), readervar.name),
                    initargs=[ExprDeref(replyexpr), actor or ExprVar.THIS],
                ),
            ]
            + [Whitespace.NL]
            + [
                _ParamTraits.checkedRead(
                    r.ipdltype,
                    r.bareType(side),
                    tempvar(r),
                    ExprAddrOf(readervar),
                    errfn,
                    "'%s'" % r.ipdltype.name(),
                    sentinelKey=r.name,
                    errfnSentinel=errfnSentinel,
                )
                for r in md.returns
            ]
            # Move-assign the values out of the variables created with
            # checkedRead into outparams.
            + [
                StmtExpr(ExprAssn(ExprDeref(r.var()), ExprMove(tempvar(r))))
                for r in md.returns
            ]
            + [StmtCode("${reader}.EndRead();", reader=readervar)]
        )

        return stmts

    def sendAsync(self, md, msgexpr, actor=None):
        """Build the statements that asynchronously send |msgexpr|.

        The send goes through ``ChannelSend`` on |actor| (or ``this``).
        Messages with returns pass the reply id and the moved ``aResolve`` /
        ``aReject`` callbacks; for those, the result variable is None.
        Otherwise the send result is stored in ``sendok__``, which is
        returned. Returns ``(retvar, stmts)``.
        """
        sendok = ExprVar("sendok__")
        stmts = [
            Whitespace.NL,
            self.logMessage(md, msgexpr, "Sending ", actor),
            self.profilerLabel(md),
            Whitespace.NL,
        ]

        # Generate the actual call expression.
        if actor is None:
            send = ExprVar("ChannelSend")
        else:
            send = ExprSelect(actor, "->", "ChannelSend")

        if not md.returns:
            stmts.append(
                StmtDecl(
                    Decl(Type.BOOL, sendok.name),
                    init=ExprCall(send, args=[ExprMove(msgexpr)]),
                )
            )
            return (sendok, stmts)

        callargs = [
            ExprMove(msgexpr),
            ExprVar(md.pqReplyId()),
            ExprMove(ExprVar("aResolve")),
            ExprMove(ExprVar("aReject")),
        ]
        stmts.append(StmtExpr(ExprCall(send, args=callargs)))
        return (None, stmts)

    def sendBlocking(self, md, msgexpr, replyexpr, actor=None):
        """Build the statements that synchronously send |msgexpr|.

        Interrupt messages use ``ChannelCall``; all others use
        ``ChannelSend``. The call is made on |actor| (or ``this``), passing
        the address of |replyexpr| for the reply. It runs inside a block
        carrying a "Sync IPC" profiler tracing marker, and its result is
        assigned to ``sendok__``. Returns ``(sendok, stmts)``.
        """
        callee = "ChannelCall" if md.decl.type.isInterrupt() else "ChannelSend"
        if actor is None:
            send = ExprVar(callee)
        else:
            send = ExprSelect(actor, "->", callee)

        sendok = ExprVar("sendok__")
        self.externalIncludes.add("mozilla/ProfilerMarkers.h")

        marker = StmtExpr(
            ExprCall(
                ExprVar("AUTO_PROFILER_TRACING_MARKER"),
                [
                    ExprLiteral.String("Sync IPC"),
                    ExprLiteral.String(
                        "%s::%s" % (self.protocol.name, md.prettyMsgName())
                    ),
                    ExprVar("IPC"),
                ],
            )
        )
        docall = StmtExpr(
            ExprAssn(
                sendok,
                ExprCall(send, args=[ExprMove(msgexpr), ExprAddrOf(replyexpr)]),
            )
        )

        stmts = [
            Whitespace.NL,
            self.logMessage(md, msgexpr, "Sending ", actor),
            self.profilerLabel(md),
            Whitespace.NL,
            StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE),
            StmtBlock([marker, docall]),
        ]
        return (sendok, stmts)

    def sendAsyncWithPromise(self, md):
        """Build the body of the promise-returning send overload for |md|.

        Creates a new promise, forwards to the callback-based send overload
        with lambdas that resolve or reject it, and returns the promise.
        """
        promise = _makePromise(md.returns, self.side, resolver=True)

        returns = md.returns
        resolvetype = (
            _tuple([r.bareType(self.side) for r in returns])
            if len(returns) > 1
            else returns[0].bareType(self.side)
        )

        onresolve = ExprCode(
            """
            [promise__](${resolvetype}&& aValue) {
                promise__->Resolve(std::move(aValue), __func__);
            }
            """,
            resolvetype=resolvetype,
        )
        # The reject lambda does not depend on the resolve type.
        onreject = ExprCode(
            """
            [promise__](ResponseRejectReason&& aReason) {
                promise__->Reject(std::move(aReason), __func__);
            }
            """
        )

        callargs = [ExprMove(p.var()) for p in md.params]
        callargs += [onresolve, onreject]
        body = StmtCode(
            """
            RefPtr<${promise}> promise__ = new ${promise}(__func__);
            promise__->UseDirectTaskDispatch(__func__);
            ${send}($,{args});
            return promise__;
            """,
            promise=promise,
            send=md.sendMethod(),
            args=callargs,
        )
        return [body]

    def callAllocActor(self, md, retsems, side):
        """Return a declaration of constructor |md|'s actor local.

        The local is initialized by calling the constructed protocol's Alloc
        method for |side|. For refcounted protocols the local is a RefPtr to
        the actor type rather than a raw pointer.
        """
        actortype = md.actorDecl().bareType(self.side)
        ctortype = md.decl.type.constructedType()
        if ctortype.isRefcounted():
            actortype.ptr = False
            actortype = _refptr(actortype)

        allocargs = md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False)
        allocation = self.thisCall(_allocMethod(ctortype, side), args=allocargs)

        return StmtDecl(Decl(actortype, md.actorDecl().var().name), init=allocation)

    def invokeRecvHandler(self, md):
        """Return statements that call |md|'s Recv handler and check its result.

        Arguments are moved in. Async messages with returns receive the
        resolver; other messages receive their out-values. If the handler's
        ``IPCResult`` is falsy, the generated code returns
        ``_Result.ProcessingError``.
        """
        wantsresolver = md.decl.type.isAsync() and md.returns
        handlerargs = md.makeCxxArgs(
            paramsems="move",
            retsems="resolver" if wantsresolver else "in",
            retcallsems="out",
        )
        okdecl = StmtDecl(
            Decl(Type("mozilla::ipc::IPCResult"), "__ok"),
            init=self.thisCall(md.recvMethod(), handlerargs),
        )

        onfailure = StmtIf(ExprNot(ExprVar("__ok")))
        onfailure.addifstmts(
            [
                _protocolErrorBreakpoint("Handler returned error code!"),
                Whitespace(
                    "// Error handled in mozilla::ipc::IPCResult\n", indent=True
                ),
                StmtReturn(_Result.ProcessingError),
            ]
        )
        return [okdecl, onfailure]

    def makeDtorMethodDecl(self, md, actorvar):
        """Return the declaration of the static ``Send__delete__`` method.

        It is the normal send declaration with an extra leading parameter
        named after |actorvar|, typed as the constructed actor.
        """
        decl = self.makeSendMethodDecl(md)
        actorparamtype = _cxxInType(
            ipdl.type.ActorType(md.decl.type.constructedType()),
            side=self.side,
            direction="send",
        )
        decl.params.insert(0, Decl(actorparamtype, actorvar.name))
        decl.methodspec = MethodSpec.STATIC
        return decl

    def makeSendMethodDecl(self, md, promise=False, paramsems="in"):
        """Return the C++ declaration of |md|'s send method.

        Async messages with returns yield either a void callback-based method
        or, when |promise| is set, a method returning a RefPtr to the
        promise. Every other message yields a bool-returning method with
        out-parameters; |promise| must be False for those. Constructors
        return the actor instead.

        The method is marked warn-unused on the parent side (unless
        callback-based), and for sync constructors.
        """
        msgtype = md.decl.type
        if not (msgtype.isAsync() and md.returns):
            assert not promise
            returnsems, rettype = "out", Type.BOOL
        elif promise:
            returnsems, rettype = "promise", _refptr(Type(md.promiseName()))
        else:
            returnsems, rettype = "callback", Type.VOID

        params = md.makeCxxParams(
            paramsems,
            returnsems=returnsems,
            side=self.side,
            implicit=msgtype.hasImplicitActorParam(),
            direction="send",
        )
        warnunused = (self.side == "parent" and returnsems != "callback") or (
            msgtype.isCtor() and not msgtype.isAsync()
        )
        decl = MethodDecl(
            md.sendMethod(),
            params=params,
            warn_unused=warnunused,
            ret=rettype,
        )
        if msgtype.isCtor():
            decl.ret = md.actorDecl().bareType(self.side)
        return decl

    def logMessage(self, md, msgptr, pfx, actor=None, receiving=False):
        """Return a statement that logs |msgptr| when IPC logging is enabled.

        |pfx| is the log prefix. The peer pid comes from the toplevel
        protocol of |actor|, or of ``this`` when |actor| is None. |receiving|
        selects the logged direction. |md| is not used.
        """
        direction = "eSending"
        if receiving:
            direction = "eReceiving"
        protoname = ExprLiteral.String(_actorName(self.protocol.name, self.side))

        return StmtCode(
            """
            if (mozilla::ipc::LoggingEnabledFor(${actorname})) {
                mozilla::ipc::LogMessageForProtocol(
                    ${actorname},
                    ${actor}->ToplevelProtocol()->OtherPidMaybeInvalid(),
                    ${pfx},
                    ${msgptr}->type(),
                    mozilla::ipc::MessageDirection::${direction});
            }
            """,
            actorname=protoname,
            actor=actor or ExprVar.THIS,
            pfx=ExprLiteral.String(pfx),
            msgptr=msgptr,
            direction=direction,
        )

    def profilerLabel(self, md):
        """Return an ``AUTO_PROFILER_LABEL`` statement naming |md|.

        Also records the ``mozilla/ProfilerLabels.h`` include it needs.
        """
        self.externalIncludes.add("mozilla/ProfilerLabels.h")
        label = StmtCode(
            """
            AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER);
            """,
            name=self.protocol.name,
            msgname=md.prettyMsgName(),
        )
        return label

    def saveActorId(self, md):
        """Return ``(idvar, stmts)`` for capturing this actor's routing id.

        The declaration of ``id__`` is only emitted when |md| has a reply;
        otherwise |stmts| is empty.
        """
        idvar = ExprVar("id__")
        if not md.decl.type.hasReply():
            return idvar, []
        # only save the ID if we're actually going to use it, to
        # avoid unused-variable warnings
        savestmt = StmtDecl(Decl(_actorIdType(), idvar.name), self.protocol.routingId())
        return idvar, [savestmt]


class _GenerateProtocolParentCode(_GenerateProtocolActorCode):
    """Actor code generator for the "parent" side of a protocol."""

    def __init__(self):
        _GenerateProtocolActorCode.__init__(self, "parent")

    def sendsMessage(self, md):
        """The parent sends every message whose direction is not ``in``."""
        msgtype = md.decl.type
        return not msgtype.isIn()

    def receivesMessage(self, md):
        """The parent receives ``inout`` and ``in`` messages."""
        msgtype = md.decl.type
        return msgtype.isInout() or msgtype.isIn()


class _GenerateProtocolChildCode(_GenerateProtocolActorCode):
    """Actor code generator for the "child" side of a protocol."""

    def __init__(self):
        _GenerateProtocolActorCode.__init__(self, "child")

    def sendsMessage(self, md):
        """The child sends every message whose direction is not ``out``."""
        msgtype = md.decl.type
        return not msgtype.isOut()

    def receivesMessage(self, md):
        """The child receives ``inout`` and ``out`` messages."""
        msgtype = md.decl.type
        return msgtype.isInout() or msgtype.isOut()


# -----------------------------------------------------------------------------
# Utility passes
##


def _splitClassDeclDefn(cls):
    """Destructively split |cls| methods into declarations and definitions.

    Every MethodDefn in |cls.stmts| that is not force_inline is replaced in
    place by a StmtDecl of its declaration. Its out-of-line definition, if
    any, is appended to a new Block. Returns ``(cls, defns)``.
    """
    defns = Block()

    for idx, stmt in enumerate(cls.stmts):
        if not isinstance(stmt, MethodDefn) or stmt.decl.force_inline:
            continue
        decl, defn = _splitMethodDeclDefn(stmt, cls)
        cls.stmts[idx] = StmtDecl(decl)
        # Pure methods have no definition to emit.
        if defn:
            defns.addstmts([defn, Whitespace.NL])

    return cls, defns


def _splitMethodDeclDefn(md, cls):
    """Split method definition |md| of class |cls| into ``(decl, defn)``.

    Pure methods return their declaration and None. Otherwise a deep copy of
    the original declaration is returned as the in-class declaration. |md|
    itself is mutated into an out-of-line definition: it is qualified with
    |cls|, loses its method specifier and warn_unused flag, is marked
    definition-only, and has its parameter defaults cleared.
    """
    # Pure methods have decls but no defns.
    if md.decl.methodspec == MethodSpec.PURE:
        return md.decl, None

    indecl = deepcopy(md.decl)

    outdecl = md.decl
    outdecl.cls = cls
    # Don't emit method specifiers on method defns.
    outdecl.methodspec = MethodSpec.NONE
    outdecl.warn_unused = False
    outdecl.only_for_definition = True
    # Default arguments belong on the declaration only.
    for param in (p for p in outdecl.params if isinstance(p, Param)):
        param.default = None

    return indecl, md


def _splitFuncDeclDefn(fun):
    """Split free function |fun| into ``(StmtDecl(decl), fun)``.

    |fun| must not be force_inline.
    """
    decl = fun.decl
    assert not decl.force_inline
    return StmtDecl(decl), fun
