#include <stdlib.h>
#include <time.h>
#include <string.h>
#include <stdio.h>
#include <tos.h>
#include "pktdrv.h"
#include "ip.h"
#include "icmp.h"
#include "timer.h"
#include "udp.h"
#include "inetcust.h"
#include "mbuf.h"

#include "nettrace.h"

/* Write the NUL-terminated string x to the TOS console (Bconout device 2)
   one character at a time, using the file-scope cursor dpy declared below.
   Wrapped in do { } while(0) so the macro expands to a single statement
   and is safe inside an unbraced if/else. */
#define Bconws(x) do { dpy = (x); while(*dpy) Bconout(2,*dpy++); } while(0)

static char *dpy;				/* scratch cursor used only by Bconws() */

/* Debug switches: rename e.g. noDEBUGRECV to DEBUGRECV to enable the
   matching printf tracing in this file. */
#define noDEBUGDROP
#define noDEBUGRECV
#define noDEBUGOPN
#define noDEBUGCLS
#define noDEBUGWR

#define min(a,b) ((a) < (b) ? (a) : (b))

UDP_CTL  *udp_list = NULL;		/* all open control blocks, newest first (udp_open prepends) */
UDP_CTL **udp_tab;				/* descriptor -> control block, udp_tablen slots */
int		  udp_tablen = 0;
extern char *udp_buffers;		/* pool control blocks are taken from (buf_alloc/buf_free) */
int udp_handler(PACKET *,int,INADDR);
int udp_du_handler(IP *);
void udp_droppacket(TIMER);
u_short udp_newport(void);
long udp_counts[2] = {0,0};		/* [0] datagrams delivered to upcalls, [1] datagrams passed to ip_send() */

/* Initialise the UDP layer: allocate the descriptor table (UDP_MAXPORTS
   empty slots), empty the control-block list and register udp_handler /
   udp_du_handler with the IP layer for protocol IP_UDP.
   Returns TRUE on success.  Returns FALSE if the table cannot be allocated
   or the IP registration fails; in both cases udp_tab is left NULL and
   udp_tablen 0, so no later routine indexes a missing or freed table. */
int udp_init(void)
{
register int i;

	udp_list = NULL;
	udp_tablen = UDP_MAXPORTS;
	udp_tab = (UDP_CTL **)getmem((size_t)udp_tablen*sizeof(UDP_CTL *));
	if(!udp_tab)
	{
		udp_tablen = 0;		/* otherwise udp_open() would index a NULL table */
		return(FALSE);
	}
	for(i=0; i<udp_tablen; i++)
		udp_tab[i] = NULL;
	if(!ip_open(IP_UDP,udp_handler,udp_du_handler))
	{
		freemem(udp_tab);
		udp_tab = NULL;		/* don't leave a dangling table for udp_exit() */
		udp_tablen = 0;
		return(FALSE);
	}
	return(TRUE);
}

/* Shut down the UDP layer: close every open descriptor, release the
   descriptor table and unregister IP_UDP from the IP layer.
   Returns whatever ip_close(IP_UDP) returns. */
int udp_exit(void)
{
int i;

	if(udp_tab)
	{
		for(i=0;i<udp_tablen;i++)
			if(udp_tab[i]) udp_close(i);
		/* The table comes from getmem() in udp_init(), which also releases
		   it with freemem(); use the matching deallocator here too. */
		freemem(udp_tab);
		udp_tab = NULL;
		udp_tablen = 0;
	}
	return(ip_close(IP_UDP));
}


/* Receive routine registered with the IP layer (see udp_init) for every
   incoming IP_UDP datagram.

   pkt   - the received packet.  This routine takes ownership of it and
           either frees it or parks it on the matching port (see below).
   len   - bytes available starting at the UDP header (presumably the IP
           payload length - NOTE(review): confirm against the IP layer).
   fhost - source address of the datagram.

   The datagram is validated: the header must be present, the UDP length
   field must fit within len, and the checksum must match when one was sent.
   It is then handed to the first control block in udp_list whose local
   port equals the destination port (a block bound to port 0 accepts
   everything).

   For a port with an upcall, the packet is kept in p_udpctl->pkt so that
   p_udpctl->data stays valid after the upcall returns.  It is released by
   udp_free(), by udp_droppacket() when UDP_KEEPPKT expires, by
   udp_close(), or when the next datagram for that port arrives (which
   reports UDP_OVR).

   Returns TRUE if the datagram was accepted by a port, FALSE otherwise. */
int udp_handler(PACKET *pkt,int len,INADDR fhost)
{
UDP_CTL *p_udpctl;
IP 		*p_ip;
register UDP *p_udp;
TCP_PSEUDO tcp_ph;
u_short csum;
unsigned plen;

	p_ip = ip_head(pkt);
	p_udp = udp_head(p_ip);

	/* The header itself must have been delivered before any of its
	   fields can be read. */
	if(len < (int)sizeof(UDP))
	{
#ifdef DEBUGRECV
printf("short packet %d\n",len);
#endif
		ip_free(pkt);
		return(FALSE);
	}

	/* The UDP length covers header + data.  It may neither exceed what was
	   received nor be shorter than the header.  Bytes beyond it (e.g.
	   link-layer padding) are ignored from here on. */
	plen = p_udp->length;
	if(plen < sizeof(UDP) || plen > (unsigned)len)
	{
#ifdef DEBUGRECV
printf("invalid packet %u,%u\n",plen,len);
#endif
		ip_free(pkt);
		return(FALSE);
	}

	/* A zero checksum field means the sender did not compute one. */
	csum = p_udp->chksum;
	if(csum)
	{
		memset(&tcp_ph,0,sizeof(tcp_ph));	/* clear any pad/zero fields */
		tcp_ph.src = fhost;
		tcp_ph.dst = p_ip->dst_inaddr;
		tcp_ph.protocol = IP_UDP;
		tcp_ph.length  = plen;

		/* Store the pseudo-header sum in the checksum field, then checksum
		   the UDP segment.  The result covers pseudo header + segment and
		   is directly comparable with the received value.  (This relies on
		   chksum() returning the complemented one's-complement sum, as its
		   use in udp_write suggests.) */
		p_udp->chksum = ~chksum((u_short *)&tcp_ph,(u_short)sizeof(TCP_PSEUDO),0);
		p_udp->chksum = chksum((u_short *)p_udp,plen,0);
		/* 0xffff and 0 are the two one's-complement encodings of zero. */
		if(csum != p_udp->chksum && !(csum == 0xffff && p_udp->chksum == 0))
		{
#ifdef DEBUGRECV
printf("bad checksum %04x->%04x\n",csum,p_udp->chksum);
#endif
			ip_free(pkt);
			return(FALSE);
		}
	}

	/* ok, accept it. run through the demux table and try to upcall it */
	for(p_udpctl = udp_list; p_udpctl; p_udpctl = p_udpctl->next)
	{
		if(p_udpctl->lcl_port && (p_udpctl->lcl_port != p_udp->dst_port))
			continue;

		if(p_udpctl->upcall)
		{
#ifdef DEBUGRECV
printf("pkt [%u] from %8lx.%u to port %u\n",plen,fhost,p_udp->src_port,p_udpctl->lcl_port);
#endif
			p_udpctl->fhost = fhost;
			p_udpctl->fgn_port = p_udp->src_port;
			p_udpctl->data_len = plen - sizeof(UDP);
			p_udpctl->data = (char *)p_udp+sizeof(UDP);
			if(p_udpctl->pkt)
			{
#ifdef DEBUGRECV
printf("packet overrun\n");
#endif
				ip_free(p_udpctl->pkt);			/* throw away old packet */
				p_udpctl->udp_err = UDP_OVR;	/* signal overrun */
			}
			else
				p_udpctl->udp_err = UDP_OK;
			/* Keep the packet: data (and the upcall's buffer argument)
			   point into it.  Freeing it here would hand the upcall
			   released memory. */
			p_udpctl->pkt = pkt;
			udp_counts[0]++;
			tm_stop(p_udpctl->udp_tm);
			tm_set(UDP_KEEPPKT,udp_droppacket,p_udpctl->udp_tm);
			/* The upcall may close the port, so p_udpctl is not touched
			   after it returns. */
			(p_udpctl->upcall)(p_udpctl->handle,(char *)p_udp+sizeof(UDP),(int)(plen-sizeof(UDP)));
		}
		else
		{
			p_udpctl->fhost = 0L;
			p_udpctl->fgn_port = 0;
			p_udpctl->data_len = 0;
			p_udpctl->data = NULL;
			p_udpctl->pkt = NULL;
			ip_free(pkt);
		}
		return(TRUE);
	}

#ifdef UDP_SEND_DSTUN
	/* No port matched: answer with ICMP port unreachable, except for
	   datagrams sent to the limited broadcast address.  Compiled out unless
	   UDP_SEND_DSTUN is defined, so by default unmatched datagrams are
	   dropped silently. */
	if(p_ip->dst_inaddr != 0xffffffffL)
		icmp_dstun(p_ip->src_inaddr,p_ip,ICMP_DSTPORT);
#endif

	ip_free(pkt);
	return(FALSE);
}



/* Timer callback armed by udp_handler() with the UDP_KEEPPKT timeout.
   For every control block whose udp_tm equals the expired timer and which
   still holds a received packet, discard that packet, clear the foreign
   end point and data fields, and flag the port with UDP_MISS: the data was
   not fetched in time.  Blocks without a pending packet are left alone. */
void udp_droppacket(TIMER tm)
{
UDP_CTL *ctl;

	for(ctl = udp_list; ctl != NULL; ctl = ctl->next)
	{
		if(ctl->udp_tm != tm || ctl->pkt == NULL)
			continue;

#ifdef DEBUGDROP
printf("UDP: dropping packet from %8lx.%u\n",ctl->fhost,ctl->fgn_port);
#endif
		ctl->fhost = 0L;
		ctl->fgn_port = 0;
		ctl->data_len = 0;
		ctl->data = NULL;
		ip_free(ctl->pkt);
		ctl->pkt = NULL;
		ctl->udp_err = UDP_MISS;
	}
}


/* Map a UDP descriptor (the index returned by udp_open) to its control
   block.  Yields NULL for an out-of-range descriptor or an unused slot. */
UDP_CTL *udp_getctl(u_short udp)
{
	return (udp < udp_tablen) ? udp_tab[udp] : NULL;
}



/* Release the datagram currently held on descriptor udp.  This stops the
   descriptor's timer, frees the packet, clears the foreign end point and
   data fields, and resets udp_err to UDP_OK.
   Returns udp when a packet was released.  Returns -1 when udp is out of
   range, the slot is unused, or no packet is pending. */
int udp_free(u_short udp)
{
UDP_CTL *ctl;

	if(udp >= udp_tablen)
		return(-1);

	ctl = udp_tab[udp];
	if(ctl == NULL || ctl->pkt == NULL)
		return(-1);			/* nothing to release */

	ctl->fhost = 0L;
	ctl->fgn_port = 0;
	ctl->data_len = 0;
	ctl->data = NULL;
	tm_stop(ctl->udp_tm);
	ip_free(ctl->pkt);
	ctl->pkt = NULL;
	ctl->udp_err = UDP_OK;
	return(udp);
}



/* Destination-unreachable handler registered with the IP layer for
   IP_UDP.  p_ip is presumably the IP header quoted inside the ICMP
   message, so the UDP source port in it is our local port.
   The first control block bound to that port (or to port 0) is marked
   UDP_NORECV, and its upcall, if any, is called with a NULL buffer and
   UDP_NORECV as the length.
   Returns TRUE if a control block was found, FALSE otherwise.
   NOTE(review): the console message is printed unconditionally; it looks
   like leftover debug output. */
int udp_du_handler(IP *p_ip)
{
register UDP *hdr;
register UDP_CTL *ctl;

	Bconws("UDP destunreachable\r");
	hdr = udp_head(p_ip);

	ctl = udp_list;
	while(ctl != NULL && ctl->lcl_port != 0 && ctl->lcl_port != hdr->src_port)
		ctl = ctl->next;

	if(ctl == NULL)
		return(FALSE);

	ctl->udp_err = UDP_NORECV;
	if(ctl->upcall)
		ctl->upcall(ctl->handle,NULL,UDP_NORECV);
	return(TRUE);
}


/* Close UDP descriptor udp.  This unlinks its control block from
   udp_list, releases any packet still held on it (stopping its timer),
   returns the block to the udp_buffers pool and clears the table slot.
   Returns udp on success.  Returns -1 if udp is out of range, the slot is
   unused, or the block is not on udp_list.  (Returning 0 on failure would
   be indistinguishable from successfully closing descriptor 0.) */
int udp_close(u_short udp)
{
UDP_CTL *p_udpctl;
UDP_CTL **pp_udpctl;

	if(udp >= udp_tablen || !udp_tab[udp]) return(-1);
	pp_udpctl = &udp_list;
	p_udpctl = udp_list;
	while(p_udpctl)
	{
		if(p_udpctl == udp_tab[udp])	/* found */
		{
#ifdef DEBUGCLS
printf("UDP: close port %u\n",p_udpctl->lcl_port);
#endif
			*pp_udpctl = p_udpctl->next;	/* unlink from list */
			if(p_udpctl->pkt)
			{
				tm_stop(p_udpctl->udp_tm);
				ip_free(p_udpctl->pkt);
			}
			buf_free(udp_buffers,(char *)p_udpctl);
			udp_tab[udp] = NULL;
			return(udp);
		}
		pp_udpctl = &(p_udpctl->next);
		p_udpctl = p_udpctl->next;
	}
	return(-1);
}
	
	
/* Return TRUE if an open UDP control block is already bound to port. */
static int udp_port_in_use(u_short port)
{
UDP_CTL *p_udpctl;

	for(p_udpctl = udp_list; p_udpctl; p_udpctl = p_udpctl->next)
		if(p_udpctl->lcl_port == port)
			return(TRUE);
	return(FALSE);
}


/* Open a UDP descriptor bound to local port lcl_port.
   lcl_port - port to bind, or 0 to have one assigned by udp_newport().
              An assigned port is chosen so it does not clash with an open
              one.
   upcall   - receive callback, or NULL (udp_handler discards datagrams for
              ports without an upcall).
   Returns the descriptor (index into udp_tab).  Returns -1 if the port is
   already in use, the descriptor table is full, or no control block can
   be allocated from udp_buffers. */
int udp_open(u_short lcl_port,UDP_UPCALL upcall)
{
int i;
UDP_CTL *p_udpctl;

	if(!lcl_port)
	{
		int tries;

		/* udp_newport() steps through port numbers one at a time (apart
		   from wrap-around), and at most udp_tablen ports can be open.
		   So udp_tablen + 1 candidates always include a free one. */
		lcl_port = udp_newport();
		for(tries = 0; tries < udp_tablen && udp_port_in_use(lcl_port); tries++)
			lcl_port = udp_newport();
	}
#ifdef DEBUGOPN
printf("UDP: open port %u\n",lcl_port);
#endif
	if(udp_port_in_use(lcl_port))
	{
#ifdef DEBUGOPN
printf("UDP: port %u already in use\n",lcl_port);
#endif
		return -1;
	}

	for(i=0; i < udp_tablen; i++)
		if(!udp_tab[i]) break;
	if(i == udp_tablen)
	{
#ifdef DEBUGOPN
printf("UDP: port table full\n");
#endif
		return(-1);
	}

	p_udpctl = (UDP_CTL *)buf_alloc(udp_buffers,sizeof(UDP_CTL));
	if(!p_udpctl)
	{
#ifdef DEBUGOPN
printf("UDP: out of buffers\n");
#endif
		return(-1);
	}
	udp_tab[i] = p_udpctl;

	p_udpctl->next = udp_list;			/* prepend to the demux list */
	udp_list = p_udpctl;

	/* NOTE(review): udp_tm is not initialised here, yet udp_handler,
	   udp_free and udp_close pass it to tm_stop/tm_set, and udp_droppacket
	   matches blocks by it.  Confirm the udp_buffers pool gives each block
	   a unique, valid timer. */
	p_udpctl->lcl_port = lcl_port;		/* fill in connection info */
	p_udpctl->fgn_port = 0;
	p_udpctl->fhost = 0L;
	p_udpctl->upcall = upcall;
	p_udpctl->pkt = NULL;
	p_udpctl->data_len = 0;
	p_udpctl->data = NULL;
	p_udpctl->udp_err = UDP_OK;

	return(i);
}


/* Send len bytes from data as one UDP datagram from descriptor udp's
   local port to fhost:port.  fhost and port are also recorded as the
   descriptor's foreign end point.
   Returns len on success.  Returns -1 in any of these cases:
   - the descriptor is bad or unused;
   - fhost, port or data is missing;
   - the payload is too large for the 16-bit UDP length field;
   - the descriptor was flagged UDP_NORECV;
   - the packet cannot be allocated;
   - ip_send() reports an error (negative result). */
int udp_write(u_short udp, char *data, u_short len, INADDR fhost, u_short port)
{
register UDP	*p_udp;
TCP_PSEUDO 		 udp_ph;
PACKET 			*pkt;
UDP_CTL			*p_udpctl;
register unsigned udplen;
u_short 		 csum;
int 			 ret;

	if(udp >= udp_tablen || !udp_tab[udp] ||
	   !fhost || !port || !data)
		return(-1);

	/* Header + payload must fit the 16-bit UDP length field.  Computing it
	   in a 16-bit int would otherwise overflow silently. */
	if((unsigned long)len + sizeof(UDP) > 0xffffUL)
		return(-1);

	p_udpctl = udp_tab[udp];
	if(p_udpctl->udp_err == UDP_NORECV) return(-1);

	udplen = len + (unsigned)sizeof(UDP);
	pkt = ip_alloc(udplen,0);
	if(!pkt) return(-1);

	p_udp = (UDP *)ip_data(pkt);

	p_udp->length = udplen;
	p_udp->src_port = p_udpctl->lcl_port;
	p_udp->dst_port = port;
	p_udp->chksum = 0;

	p_udpctl->fgn_port = port;
	p_udpctl->fhost = fhost;

	memset(&udp_ph,0,sizeof(udp_ph));	/* clear any pad/zero fields */
	udp_ph.src = ip_myaddr();
	udp_ph.dst = fhost;
	udp_ph.protocol = IP_UDP;
	udp_ph.length = udplen;

	if(len) memcpy(udp_data(p_udp),data,len);		/* copy data */

	csum = ~chksum((u_short *)&udp_ph,(int)sizeof(TCP_PSEUDO),0);
	p_udp->chksum = chksum((u_short *)p_udp,udp_ph.length,csum);
	/* RFC 768: a computed checksum of zero is transmitted as all ones,
	   because zero on the wire means "no checksum". */
	if(p_udp->chksum == 0)
		p_udp->chksum = 0xffff;

	udp_counts[1]++;

	ret = ip_send(IP_UDP,pkt,udp_ph.length,fhost);

#ifdef DEBUGWR
printf("UDP: pkt[%d] %u to   %08lx.%u\n",udp_ph.length,p_udpctl->lcl_port,fhost,port);
if(ret < 0) printf("ip_send: network error %d\n",ret);
#endif
	ip_free(pkt);
	if(ret < 0)
		return(-1);		/* report the network error instead of success */
	return(len);
}


/* Hand out a local port number for an auto-bound UDP descriptor.
   The first call seeds the counter from clock(); each later call returns
   the previous value plus one, wrapping at 16 bits.  Values below 1200
   are shifted up by 1200, presumably to stay clear of well-known service
   ports. */
u_short udp_newport(void)
{
static u_short next_port = 0;

	next_port = next_port ? (u_short)(next_port + 1) : (u_short)clock();

	if(next_port < 1200)
		next_port += 1200;

	return(next_port);
}