#include "packet_attribute_parser.h"
#include "core.h"

#include <stdio.h>
#include <string.h>

/* Forward declarations of the checksum helpers defined further down in this
 * file, so the module functions below can call them. */
u_int16_t tcp_checksum(struct packet_attribute attr);
u_int16_t ip_checksum(struct packet_attribute attr);

/*
 * Purpose: recompute the TCP and IP header checksums of the packet that
 *          attr points at (attr is passed by value, but the headers are
 *          modified in place through its pointers)
 * Arguments: connection (not used), packet_attribute
 * Returns: always NULL
 */
void * default_layer1_write_module(struct connection * connection, struct packet_attribute attr)
{
	u_int16_t computed;

	(void) connection;

	if(attr.tcp) {
		/* The checksum field is cleared before the sum is taken over it. */
		attr.tcp_header->check = 0;
		computed = tcp_checksum(attr);
		/* tcp_checksum's result is byte-swapped with htons before storing. */
		attr.tcp_header->check = htons(computed);
	}

	if(attr.ip) {
		attr.ip_header->check = 0;
		computed = ip_checksum(attr);
		/* ip_checksum's result is stored as returned, without a byte swap. */
		attr.ip_header->check = computed;
	}

	return NULL;
}

/*
 * Purpose: read module that performs no work on the packet
 * Arguments: connection (not used), packet_attribute (not used)
 * Returns: always NULL
 */
void * default_layer1_read_module(struct connection * connection, struct packet_attribute attr)
{
	/* Both parameters are intentionally ignored. */
	(void) connection;
	(void) attr;

	return NULL;
}

/*
 * Purpose: recompute the TCP and IP header checksums of the packet that
 *          attr points at; the headers are updated in place through the
 *          pointers carried in attr
 * Arguments: connection (not used), packet_attribute
 * Returns: always NULL
 */
void * default_layer1_writeback_module(struct connection * connection, struct packet_attribute attr)
{
	u_int16_t fresh_sum;

	(void) connection;

	if(attr.tcp) {
		/* Zero the field first so the old value does not enter the sum. */
		attr.tcp_header->check = 0;
		fresh_sum = tcp_checksum(attr);
		attr.tcp_header->check = htons(fresh_sum);
	}

	if(attr.ip) {
		attr.ip_header->check = 0;
		fresh_sum = ip_checksum(attr);
		/* Stored exactly as returned by ip_checksum (no htons here). */
		attr.ip_header->check = fresh_sum;
	}

	return NULL;
}

/*
 * Purpose: calculate the TCP checksum, i.e. the one's-complement sum over
 *          the pseudo header (source address, destination address,
 *          protocol number, TCP length) and the TCP segment
 * Arguments: packet_attribute - ip_header and tcp_header must point at the
 *            packet's IP and TCP headers; callers in this file zero
 *            tcp_header->check before calling
 * Returns: checksum as a host-order value (callers in this file apply
 *          htons before storing it). The packet buffer is only read,
 *          never written.
 */
u_int16_t tcp_checksum(struct packet_attribute attr)
{
    u_int32_t sum;
    const unsigned char * saddr, * daddr, * tcp;
    int i, tcp_len;

    sum = 0;

    /* Read raw bytes as unsigned char: with plain char, bytes >= 0x80 may be
     * negative and left-shifting them is undefined behaviour. */
    saddr = (const unsigned char *) &(attr.ip_header->saddr);
    daddr = (const unsigned char *) &(attr.ip_header->daddr);
    tcp = (const unsigned char *) attr.tcp_header;

    /* TCP segment length = IP total length minus IP header length. */
    tcp_len = ntohs(attr.ip_header->tot_len) - (attr.ip_header->ihl<<2);
    if(tcp_len < 0) {
        /* Inconsistent IP header (total length shorter than the header):
         * treat the segment as empty instead of indexing before tcp. */
        tcp_len = 0;
    }

    /* Sum the segment as big-endian 16-bit words. */
    for(i=0; i+1<tcp_len; i+=2) {
        sum += ((u_int32_t) tcp[i] << 8) + tcp[i+1];
    }

    /* Odd length: the final byte is the high half of a word whose low half
     * is a zero pad. Add it arithmetically rather than storing a pad byte
     * past the end of the segment. */
    if(tcp_len & 1) {
        sum += (u_int32_t) tcp[tcp_len-1] << 8;
    }

    /* Pseudo header: the two 4-byte addresses as big-endian words ... */
    for(i=0; i<4; i+=2) {
        sum += ((u_int32_t) saddr[i] << 8) + saddr[i+1];
        sum += ((u_int32_t) daddr[i] << 8) + daddr[i+1];
    }

    /* ... plus the protocol number (6 = TCP) and the TCP length. */
    sum += 6 + tcp_len;

    /* Fold the carries back into the low 16 bits. */
    while(sum>>16)
        sum = (sum & 0xffff) + (sum >> 16);

    return (u_int16_t) ~sum;
}

/*
 * Purpose: calculate the IP header checksum (one's-complement sum of the
 *          header's 16-bit words)
 * Arguments: packet_attribute - ip_header must point at the IP header;
 *            callers in this file zero ip_header->check before calling
 * Returns: checksum built from words read in memory order; callers in this
 *          file store it in the header as returned, without a byte swap
 */
u_int16_t ip_checksum(struct packet_attribute attr)
{
	const unsigned char * bytes;
	u_int16_t word;
	u_int32_t sum;
	int count, offset;

	bytes = (const unsigned char *) attr.ip_header;

	/* ihl << 2 is always a multiple of four, so the header length is even
	 * and there is never a trailing odd byte to account for. */
	count = (attr.ip_header->ihl) << 2;

	sum = 0;
	for(offset = 0; offset + 2 <= count; offset += 2) {
		/* Copy each word out with memcpy instead of dereferencing a
		 * u_int16_t pointer cast from the header struct; this avoids
		 * aliasing and alignment assumptions while reading the same
		 * bytes in the same order. */
		memcpy(&word, bytes + offset, sizeof word);
		sum += word;
	}

	/* Fold the carries back into the low 16 bits. */
	while(sum >> 16) {
		sum = (sum & 0xffff) + (sum >> 16);
	}

	return (u_int16_t) ~sum;
}
