/*
 * Copyright (c) 2007,2008,2009 INRIA, UDCAST
 * Copyright (c) 2011 Centre Tecnologic de Telecomunicacions de Catalunya (CTTC)
 *
 * SPDX-License-Identifier: GPL-2.0-only
 *
 * The original version of UdpClient is by  Amine Ismail
 * <amine.ismail@sophia.inria.fr> <amine.ismail@udcast.com>
 * The rest of the code (including modifying UdpClient into
 *  NrEpsBearerTagUdpClient) is by Nicola Baldo <nbaldo@cttc.es>
 */

#include "nr-test-entities.h"

#include "ns3/arp-cache.h"
#include "ns3/boolean.h"
#include "ns3/config.h"
#include "ns3/csma-helper.h"
#include "ns3/inet-socket-address.h"
#include "ns3/internet-stack-helper.h"
#include "ns3/ipv4-address-helper.h"
#include "ns3/ipv4-interface.h"
#include "ns3/ipv4-static-routing-helper.h"
#include "ns3/ipv4-static-routing.h"
#include "ns3/log.h"
#include "ns3/mac48-address.h"
#include "ns3/nr-epc-gnb-application.h"
#include "ns3/nr-eps-bearer-tag.h"
#include "ns3/nr-point-to-point-epc-helper.h"
#include "ns3/packet-sink-helper.h"
#include "ns3/packet-sink.h"
#include "ns3/point-to-point-helper.h"
#include "ns3/seq-ts-header.h"
#include "ns3/simulator.h"
#include "ns3/test.h"
#include "ns3/uinteger.h"

using namespace ns3;

NS_LOG_COMPONENT_DEFINE("NrEpcTestS1uUplink");

/**
 * @ingroup nr-test
 *
 * A UDP client that sends packets carrying a sequence number and a time
 * stamp (SeqTsHeader) and, in addition, an NrEpsBearerTag. This tag is
 * normally generated by the NrGnbNetDevice when forwarding packets in the
 * uplink. This test has no NrGnbNetDevice, because it exercises the S1-U
 * interface with simpler devices, so the client attaches the tag itself.
 */
class NrEpsBearerTagUdpClient : public Application
{
  public:
    /**
     * @brief Get the type ID.
     * @return the object TypeId
     */
    static TypeId GetTypeId();

    /**
     * Default constructor. The RNTI and bearer ID stored in the tag are both 0.
     */
    NrEpsBearerTagUdpClient();

    /**
     * Constructor
     *
     * @param rnti the RNTI written into the NrEpsBearerTag of every packet sent
     * @param bid the bearer ID written into the NrEpsBearerTag of every packet sent
     */
    NrEpsBearerTagUdpClient(uint16_t rnti, uint8_t bid);

    ~NrEpsBearerTagUdpClient() override;

    /**
     * @brief set the remote address and port
     * @param ip remote IP address
     * @param port remote port
     */
    void SetRemote(Ipv4Address ip, uint16_t port);

  protected:
    void DoDispose() override;

  private:
    void StartApplication() override;
    void StopApplication() override;

    /// Send one packet and, unless MaxPackets has been reached, schedule the next one
    void Send();

    uint32_t m_count{0}; ///< maximum number of packets to send (0 means unlimited)
    Time m_interval;     ///< the time between packets
    uint32_t m_size{0};  ///< the size of packets generated

    uint32_t m_sent{0};        ///< number of packets sent
    Ptr<Socket> m_socket;      ///< the socket
    Ipv4Address m_peerAddress; ///< the peer address of the outbound packets
    uint16_t m_peerPort{0};    ///< the destination port of the outbound packets
    EventId m_sendEvent;       ///< the send event

    uint16_t m_rnti{0}; ///< the RNTI
    uint8_t m_bid{0};   ///< the bearer identifier
};

TypeId
NrEpsBearerTagUdpClient::GetTypeId()
{
    // Smallest PacketSize the client can honour: every packet carries the
    // 12-byte header with the sequence number and the time stamp, and Send()
    // subtracts that size from the (unsigned) packet size. Enforcing the
    // documented minimum here rejects values that would wrap around.
    constexpr uint32_t minPacketSize = 12;

    static TypeId tid =
        TypeId("ns3::NrEpsBearerTagUdpClient")
            .SetParent<Application>()
            .AddConstructor<NrEpsBearerTagUdpClient>()
            .AddAttribute(
                "MaxPackets",
                "The maximum number of packets the application will send (zero means infinite)",
                UintegerValue(100),
                MakeUintegerAccessor(&NrEpsBearerTagUdpClient::m_count),
                MakeUintegerChecker<uint32_t>())
            .AddAttribute("Interval",
                          "The time to wait between packets",
                          TimeValue(Seconds(1.0)),
                          MakeTimeAccessor(&NrEpsBearerTagUdpClient::m_interval),
                          MakeTimeChecker())
            .AddAttribute("RemoteAddress",
                          "The destination Ipv4Address of the outbound packets",
                          Ipv4AddressValue(),
                          MakeIpv4AddressAccessor(&NrEpsBearerTagUdpClient::m_peerAddress),
                          MakeIpv4AddressChecker())
            .AddAttribute("RemotePort",
                          "The destination port of the outbound packets",
                          UintegerValue(100),
                          MakeUintegerAccessor(&NrEpsBearerTagUdpClient::m_peerPort),
                          MakeUintegerChecker<uint16_t>())
            .AddAttribute("PacketSize",
                          "Size of packets generated. The minimum packet size is 12 bytes which is "
                          "the size of the header carrying the sequence number and the time stamp.",
                          UintegerValue(1024),
                          MakeUintegerAccessor(&NrEpsBearerTagUdpClient::m_size),
                          MakeUintegerChecker<uint32_t>(minPacketSize));
    return tid;
}

NrEpsBearerTagUdpClient::NrEpsBearerTagUdpClient()
    : m_sent(0),
      m_socket(nullptr),
      m_sendEvent(),
      m_rnti(0),
      m_bid(0)
{
    // The remaining members (count, interval, size, peer address/port) are
    // bound to the attributes declared in GetTypeId().
    NS_LOG_FUNCTION_NOARGS();
}

NrEpsBearerTagUdpClient::NrEpsBearerTagUdpClient(uint16_t rnti, uint8_t bid)
    : m_sent(0),
      m_socket(nullptr),
      m_sendEvent(),
      m_rnti(rnti),
      m_bid(bid)
{
    // rnti/bid are copied into the NrEpsBearerTag attached to every packet in
    // Send(); the other members are bound to attributes in GetTypeId().
    NS_LOG_FUNCTION_NOARGS();
}

NrEpsBearerTagUdpClient::~NrEpsBearerTagUdpClient()
{
    // No manual cleanup: all members are values or smart pointers.
    NS_LOG_FUNCTION_NOARGS();
}

void
NrEpsBearerTagUdpClient::SetRemote(Ipv4Address ip, uint16_t port)
{
    // Only records the destination; the socket is connected to it in
    // StartApplication() when the socket is first created.
    m_peerPort = port;
    m_peerAddress = ip;
}

void
NrEpsBearerTagUdpClient::DoDispose()
{
    NS_LOG_FUNCTION_NOARGS();
    Application::DoDispose();
}

void
NrEpsBearerTagUdpClient::StartApplication()
{
    NS_LOG_FUNCTION_NOARGS();

    if (!m_socket)
    {
        TypeId tid = TypeId::LookupByName("ns3::UdpSocketFactory");
        m_socket = Socket::CreateSocket(GetNode(), tid);
        m_socket->Bind();
        m_socket->Connect(InetSocketAddress(m_peerAddress, m_peerPort));
    }

    m_socket->SetRecvCallback(MakeNullCallback<void, Ptr<Socket>>());
    m_sendEvent = Simulator::Schedule(Seconds(0.0), &NrEpsBearerTagUdpClient::Send, this);
}

void
NrEpsBearerTagUdpClient::StopApplication()
{
    NS_LOG_FUNCTION_NOARGS();
    // Stop the send loop by cancelling the pending send; the socket stays open.
    m_sendEvent.Cancel();
}

void
NrEpsBearerTagUdpClient::Send()
{
    NS_LOG_FUNCTION_NOARGS();
    NS_ASSERT(m_sendEvent.IsExpired());
    SeqTsHeader seqTs;
    seqTs.SetSeq(m_sent);
    Ptr<Packet> p = Create<Packet>(m_size - (8 + 4)); // 8+4 : the size of the seqTs header
    p->AddHeader(seqTs);

    NrEpsBearerTag tag(m_rnti, m_bid);
    p->AddPacketTag(tag);

    if ((m_socket->Send(p)) >= 0)
    {
        ++m_sent;
        NS_LOG_INFO("TraceDelay TX " << m_size << " bytes to " << m_peerAddress << " Uid: "
                                     << p->GetUid() << " Time: " << (Simulator::Now()).As(Time::S));
    }
    else
    {
        NS_LOG_INFO("Error while sending " << m_size << " bytes to " << m_peerAddress);
    }

    if (m_sent < m_count || m_count == 0)
    {
        m_sendEvent = Simulator::Schedule(m_interval, &NrEpsBearerTagUdpClient::Send, this);
    }
}

/**
 * @ingroup nr-test
 *
 * @brief Uplink traffic description for a single UE, plus the applications
 * the test installs for it.
 */
struct NrUeUlTestData
{
    /**
     * Constructor
     *
     * @param n number of packets the UE sends
     * @param s size in bytes of each packet
     * @param r the RNTI carried in the bearer tag
     * @param l the bearer ID carried in the bearer tag
     */
    NrUeUlTestData(uint32_t n, uint32_t s, uint16_t r, uint8_t l);

    uint32_t numPkts; ///< number of packets to send
    uint32_t pktSize; ///< size of each packet, in bytes
    uint16_t rnti;    ///< RNTI of the UE
    uint8_t bid;      ///< bearer ID used by the UE

    Ptr<PacketSink> serverApp;  ///< sink on the remote host receiving this UE's traffic
    Ptr<Application> clientApp; ///< client application installed on the UE
};

NrUeUlTestData::NrUeUlTestData(uint32_t n, uint32_t s, uint16_t r, uint8_t l)
    : numPkts(n),
      pktSize(s),
      rnti(r),
      bid(l)
{
}

/**
 * @ingroup nr-test
 *
 * @brief Uplink test data for one gNB: the uplink traffic description of
 * every UE attached to it.
 */
struct GnbUlTestData
{
    std::vector<NrUeUlTestData> ues; ///< one entry per UE attached to this gNB
};

/**
 * @ingroup nr-test
 *
 * @brief S1-U uplink test case: UEs attached to one or more gNBs send UDP
 * traffic through the EPC to sinks on a remote host, and the bytes received
 * by each sink are compared with the bytes the UE was configured to send.
 */
class NrEpcS1uUlTestCase : public TestCase
{
  public:
    /**
     * Constructor
     *
     * @param name the name of the test case
     * @param v uplink test data, one GnbUlTestData entry per gNB
     */
    NrEpcS1uUlTestCase(std::string name, std::vector<GnbUlTestData> v);
    ~NrEpcS1uUlTestCase() override;

  private:
    void DoRun() override;

    std::vector<GnbUlTestData> m_gnbUlTestData; ///< per-gNB uplink test data
};

NrEpcS1uUlTestCase::NrEpcS1uUlTestCase(std::string name, std::vector<GnbUlTestData> v)
    : TestCase(name),
      m_gnbUlTestData(v)
{
    // Only stores the test data; the whole scenario is built in DoRun().
}

NrEpcS1uUlTestCase::~NrEpcS1uUlTestCase()
{
    // Nothing to release: members clean themselves up.
}

void
NrEpcS1uUlTestCase::DoRun()
{
    Ptr<NrPointToPointEpcHelper> nrEpcHelper = CreateObject<NrPointToPointEpcHelper>();
    Ptr<Node> pgw = nrEpcHelper->GetPgwNode();

    // allow jumbo packets
    Config::SetDefault("ns3::CsmaNetDevice::Mtu", UintegerValue(30000));
    Config::SetDefault("ns3::PointToPointNetDevice::Mtu", UintegerValue(30000));
    nrEpcHelper->SetAttribute("S1uLinkMtu", UintegerValue(30000));

    // Create a single RemoteHost
    NodeContainer remoteHostContainer;
    remoteHostContainer.Create(1);
    Ptr<Node> remoteHost = remoteHostContainer.Get(0);
    InternetStackHelper internet;
    internet.Install(remoteHostContainer);

    // Create the internet
    PointToPointHelper p2ph;
    p2ph.SetDeviceAttribute("DataRate", DataRateValue(DataRate("100Gb/s")));
    NetDeviceContainer internetDevices = p2ph.Install(pgw, remoteHost);
    Ipv4AddressHelper ipv4h;
    ipv4h.SetBase("1.0.0.0", "255.0.0.0");
    Ipv4InterfaceContainer internetNodesIpIfaceContainer = ipv4h.Assign(internetDevices);

    // setup default gateway for the remote hosts
    Ipv4StaticRoutingHelper ipv4RoutingHelper;
    Ptr<Ipv4StaticRouting> remoteHostStaticRouting =
        ipv4RoutingHelper.GetStaticRouting(remoteHost->GetObject<Ipv4>());

    // hardcoded UE addresses for now
    remoteHostStaticRouting->AddNetworkRouteTo(Ipv4Address("7.0.0.0"),
                                               Ipv4Mask("255.255.255.0"),
                                               1);

    uint16_t udpSinkPort = 1234;

    NodeContainer gnbs;
    uint16_t cellIdCounter = 0;
    uint64_t imsiCounter = 0;

    for (auto gnbit = m_gnbUlTestData.begin(); gnbit < m_gnbUlTestData.end(); ++gnbit)
    {
        Ptr<Node> gnb = CreateObject<Node>();
        gnbs.Add(gnb);

        // we test EPC without LTE, hence we use:
        // 1) a CSMA network to simulate the cell
        // 2) a raw socket opened on the CSMA device to simulate the NR socket

        uint16_t cellId = ++cellIdCounter;

        NodeContainer ues;
        ues.Create(gnbit->ues.size());

        NodeContainer cell;
        cell.Add(ues);
        cell.Add(gnb);

        CsmaHelper csmaCell;
        NetDeviceContainer cellDevices = csmaCell.Install(cell);

        // the eNB's CSMA NetDevice acting as an NR NetDevice.
        Ptr<NetDevice> gnbDevice = cellDevices.Get(cellDevices.GetN() - 1);

        // Note that the NrEpcGnbApplication won't care of the actual NetDevice type
        std::vector<uint16_t> cellIds;
        cellIds.push_back(cellId);
        nrEpcHelper->AddGnb(gnb, gnbDevice, cellIds);

        // Plug test RRC entity
        Ptr<NrEpcGnbApplication> gnbApp = gnb->GetApplication(0)->GetObject<NrEpcGnbApplication>();
        NS_ASSERT_MSG(gnbApp, "cannot retrieve NrEpcGnbApplication");
        Ptr<NrEpcTestRrc> rrc = CreateObject<NrEpcTestRrc>();
        gnb->AggregateObject(rrc);
        rrc->SetS1SapProvider(gnbApp->GetS1SapProvider());
        gnbApp->SetS1SapUser(rrc->GetS1SapUser());

        // we install the IP stack on UEs only
        InternetStackHelper internet;
        internet.Install(ues);

        // assign IP address to UEs, and install applications
        for (uint32_t u = 0; u < ues.GetN(); ++u)
        {
            Ptr<NetDevice> ueNrDevice = cellDevices.Get(u);
            Ipv4InterfaceContainer ueIpIface =
                nrEpcHelper->AssignUeIpv4Address(NetDeviceContainer(ueNrDevice));

            Ptr<Node> ue = ues.Get(u);

            // disable IP Forwarding on the UE. This is because we use
            // CSMA broadcast MAC addresses for this test. The problem
            // won't happen with a NrUeNetDevice.
            Ptr<Ipv4> ueIpv4 = ue->GetObject<Ipv4>();
            ueIpv4->SetAttribute("IpForward", BooleanValue(false));

            // tell the UE to route all packets to the GW
            Ptr<Ipv4StaticRouting> ueStaticRouting = ipv4RoutingHelper.GetStaticRouting(ueIpv4);
            Ipv4Address gwAddr = nrEpcHelper->GetUeDefaultGatewayAddress();
            NS_LOG_INFO("GW address: " << gwAddr);
            ueStaticRouting->SetDefaultRoute(gwAddr, 1);

            // since the UEs in this test use CSMA with IP enabled, and
            // the gNB uses CSMA but without IP, we fool the UE's ARP
            // cache into thinking that the IP address of the GW can be
            // reached by sending a CSMA packet to the broadcast
            // address, so the gNB will get it.
            int32_t ueNrIpv4IfIndex = ueIpv4->GetInterfaceForDevice(ueNrDevice);
            Ptr<Ipv4L3Protocol> ueIpv4L3Protocol = ue->GetObject<Ipv4L3Protocol>();
            Ptr<Ipv4Interface> ueNrIpv4Iface = ueIpv4L3Protocol->GetInterface(ueNrIpv4IfIndex);
            Ptr<ArpCache> ueArpCache = ueNrIpv4Iface->GetArpCache();
            ueArpCache->SetAliveTimeout(Seconds(1000));
            ArpCache::Entry* arpCacheEntry = ueArpCache->Add(gwAddr);
            arpCacheEntry->SetMacAddress(Mac48Address::GetBroadcast());
            arpCacheEntry->MarkPermanent();

            PacketSinkHelper packetSinkHelper(
                "ns3::UdpSocketFactory",
                InetSocketAddress(Ipv4Address::GetAny(), udpSinkPort));
            ApplicationContainer sinkApp = packetSinkHelper.Install(remoteHost);
            sinkApp.Start(Seconds(1.0));
            sinkApp.Stop(Seconds(10.0));
            gnbit->ues[u].serverApp = sinkApp.Get(0)->GetObject<PacketSink>();

            Time interPacketInterval = Seconds(0.01);
            Ptr<NrEpsBearerTagUdpClient> client =
                CreateObject<NrEpsBearerTagUdpClient>(gnbit->ues[u].rnti, gnbit->ues[u].bid);
            client->SetAttribute("RemoteAddress",
                                 Ipv4AddressValue(internetNodesIpIfaceContainer.GetAddress(1)));
            client->SetAttribute("RemotePort", UintegerValue(udpSinkPort));
            client->SetAttribute("MaxPackets", UintegerValue(gnbit->ues[u].numPkts));
            client->SetAttribute("Interval", TimeValue(interPacketInterval));
            client->SetAttribute("PacketSize", UintegerValue(gnbit->ues[u].pktSize));
            ue->AddApplication(client);
            ApplicationContainer clientApp;
            clientApp.Add(client);
            clientApp.Start(Seconds(2.0));
            clientApp.Stop(Seconds(10.0));
            gnbit->ues[u].clientApp = client;

            uint64_t imsi = ++imsiCounter;
            nrEpcHelper->AddUe(ueNrDevice, imsi);
            nrEpcHelper->ActivateEpsBearer(ueNrDevice,
                                           imsi,
                                           NrEpcTft::Default(),
                                           NrEpsBearer(NrEpsBearer::NGBR_VIDEO_TCP_DEFAULT));
            Simulator::Schedule(MilliSeconds(10),
                                &NrEpcGnbS1SapProvider::InitialUeMessage,
                                gnbApp->GetS1SapProvider(),
                                imsi,
                                gnbit->ues[u].rnti);
            // need this since all sinks are installed in the same node
            ++udpSinkPort;
        }
    }

    Simulator::Run();

    for (auto gnbit = m_gnbUlTestData.begin(); gnbit < m_gnbUlTestData.end(); ++gnbit)
    {
        for (auto ueit = gnbit->ues.begin(); ueit < gnbit->ues.end(); ++ueit)
        {
            NS_TEST_ASSERT_MSG_EQ(ueit->serverApp->GetTotalRx(),
                                  (ueit->numPkts) * (ueit->pktSize),
                                  "wrong total received bytes");
        }
    }

    Simulator::Destroy();
}

/**
 * @ingroup nr-test
 *
 * @brief Test suite checking that the S1-U interface implementation works
 * correctly in the uplink.
 */
class NrEpcS1uUlTestSuite : public TestSuite
{
  public:
    /// Build the suite and add all uplink S1-U test cases to it
    NrEpcS1uUlTestSuite();
};

/// Global instance of the uplink S1-U test suite
NrEpcS1uUlTestSuite g_NrEpcS1uUlTestSuiteInstance;

NrEpcS1uUlTestSuite::NrEpcS1uUlTestSuite()
    : TestSuite("nr-epc-s1u-uplink", Type::SYSTEM)
{
    // NrUeUlTestData arguments: (numPkts, pktSize, rnti, bid)

    std::vector<GnbUlTestData> v1;
    GnbUlTestData e1;
    NrUeUlTestData f1(1, 100, 1, 1);
    e1.ues.push_back(f1);
    v1.push_back(e1);
    AddTestCase(new NrEpcS1uUlTestCase("1 eNB, 1UE", v1), TestCase::Duration::QUICK);

    std::vector<GnbUlTestData> v2;
    GnbUlTestData e2;
    NrUeUlTestData f2_1(1, 100, 1, 1);
    e2.ues.push_back(f2_1);
    NrUeUlTestData f2_2(2, 200, 2, 1);
    e2.ues.push_back(f2_2);
    v2.push_back(e2);
    AddTestCase(new NrEpcS1uUlTestCase("1 eNB, 2UEs", v2), TestCase::Duration::QUICK);

    std::vector<GnbUlTestData> v3;
    v3.push_back(e1);
    v3.push_back(e2);
    AddTestCase(new NrEpcS1uUlTestCase("2 eNBs", v3), TestCase::Duration::QUICK);

    GnbUlTestData e3;
    NrUeUlTestData f3_1(3, 50, 1, 1);
    e3.ues.push_back(f3_1);
    NrUeUlTestData f3_2(5, 1472, 2, 1);
    e3.ues.push_back(f3_2);
    // Third UE with the smallest packet the client can generate: every packet
    // must at least hold the 12-byte sequence-number/time-stamp header
    // (smaller sizes would underflow in NrEpsBearerTagUdpClient::Send()).
    // Previously f3_2 was pushed a second time here by mistake, so this UE
    // was never tested and two UEs of the same gNB shared RNTI 2.
    NrUeUlTestData f3_3(1, 12, 3, 1);
    e3.ues.push_back(f3_3);
    std::vector<GnbUlTestData> v4;
    v4.push_back(e3);
    v4.push_back(e1);
    v4.push_back(e2);
    AddTestCase(new NrEpcS1uUlTestCase("3 eNBs", v4), TestCase::Duration::QUICK);

    std::vector<GnbUlTestData> v5;
    GnbUlTestData e5;
    NrUeUlTestData f5(10, 3000, 1, 1);
    e5.ues.push_back(f5);
    v5.push_back(e5);
    AddTestCase(new NrEpcS1uUlTestCase("1 eNB, 10 pkts 3000 bytes each", v5),
                TestCase::Duration::QUICK);

    std::vector<GnbUlTestData> v6;
    GnbUlTestData e6;
    NrUeUlTestData f6(50, 3000, 1, 1);
    e6.ues.push_back(f6);
    v6.push_back(e6);
    AddTestCase(new NrEpcS1uUlTestCase("1 eNB, 50 pkts 3000 bytes each", v6),
                TestCase::Duration::QUICK);

    std::vector<GnbUlTestData> v7;
    GnbUlTestData e7;
    NrUeUlTestData f7(10, 15000, 1, 1);
    e7.ues.push_back(f7);
    v7.push_back(e7);
    AddTestCase(new NrEpcS1uUlTestCase("1 eNB, 10 pkts 15000 bytes each", v7),
                TestCase::Duration::QUICK);

    std::vector<GnbUlTestData> v8;
    GnbUlTestData e8;
    NrUeUlTestData f8(100, 15000, 1, 1);
    e8.ues.push_back(f8);
    v8.push_back(e8);
    AddTestCase(new NrEpcS1uUlTestCase("1 eNB, 100 pkts 15000 bytes each", v8),
                TestCase::Duration::QUICK);
}
