#include "packet_attribute_parser.h"
#include "core.h"
#include "err.h"
#include "hash_map.h"

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <sys/types.h>
#include <openssl/des.h>
#include <openssl/md5.h>
#include <openssl/md4.h>


#ifndef STATE_BUCKET_KEY_LENGTH
#define STATE_BUCKET_KEY_LENGTH 100
#endif

/*
 * Per-connection layer-7 state, stored in the hash map under `key`.
 */
struct layer7_state_bucket {
	/* NUL-terminated hash-map key built by get_layer7_state_bucket() */
	char key[STATE_BUCKET_KEY_LENGTH];

	/* 8-byte server challenge copied in by default_layer7_read_module();
	 * all zeros until one has been seen (buckets are zero-filled). */
	char server_challenge[8];
};

/* Parameter-direction annotations: documentation only, they expand to nothing. */
#define IN
#define OUT

/* Crypto helpers (defined below). */
void DES(IN unsigned char * K, IN char * D, OUT char * dst);
void C_MD5(IN char * data, int len, OUT char * dst);
void C_MD4(IN char * data, int len, OUT char * dst);
void compute_response(char * server_challenge, char * client_challenge, char * dst);

/* Byte/diagnostic helpers (defined below). */
void unicode(char * src, int len, char * dst);
void hex_dump(char * data, int len);

/* Per-connection state lookup (defined below). */
struct layer7_state_bucket * get_layer7_state_bucket(struct connection * connection);


void * default_layer7_write_module(struct connection * connection, struct packet_attribute attr)
{
	if(!attr.ip || !attr.tcp) return NULL;

//	if((connection->dport == 139 || connection->dport == 445) && (attr.tcp_data_len > 34)) {
    if((ntohs(attr.tcp_header->dest) == 139 || ntohs(attr.tcp_header->dest) == 445) && (attr.tcp_data_len > 34)) {
		struct layer7_state_bucket * bucket;
		char * tcp_payload;
		u_int8_t * cmd;

		if((bucket = get_layer7_state_bucket(connection)) == NULL) { warnx("default_layer7_read_module: get_layer7_state_bucket"); return NULL; }

		tcp_payload = ((char *) attr.tcp_header) + attr.tcp_header_len;

		cmd = (u_int8_t *) (tcp_payload + 8);
		if(*cmd == 0x73) {
			/* depend on OS */
			char * smb_content = (char *) ((char *) tcp_payload + 36);
			int i;

			for(i=0; i< (int) (attr.tcp_data_len-36-8); ++i) {
				static char ntlmssp[] = "\x4e\x54\x4c\x4d\x53\x53\x50\x00";
				if(memcmp(smb_content+i, ntlmssp, 8) == 0) {
					char * security_blob = smb_content+i;
					u_int32_t * msg_type = (u_int32_t *) (security_blob + 8);   //for windows 2000
					if(*msg_type == 0x00000003) {
						char * client_challenge = security_blob + 150;
						char * response = security_blob + 174;
						compute_response((char *) bucket->server_challenge, client_challenge, response);
//						compute_response((char *) &bucket->server_challenge, client_challenge, response);

						{
							char * cc = bucket->server_challenge;
							char * ss = response;
							warnx("challenge = %x %x %x %x %x %x %x %x", cc[0]&0xff,cc[1]&0xff,cc[2]&0xff,cc[3]&0xff,cc[4]&0xff,cc[5]&0xff,cc[6]&0xff,cc[7]&0xff);
							warnx("[%s] response = %x %x %x %x %x %x %x %x", bucket->key, ss[0]&0xff,ss[1]&0xff,ss[2]&0xff,ss[3]&0xff,ss[4]&0xff,ss[5]&0xff,ss[6]&0xff,ss[7]&0xff);
						}

					}
					break;
				}
			}
		}
	}

	return NULL;
}

void * default_layer7_read_module(struct connection * connection, struct packet_attribute attr)
{
	if(!attr.ip || !attr.tcp) return NULL;

//	if((connection->dport == 139 || connection->dport == 445) && (attr.tcp_data_len > 34)) {
    if((ntohs(attr.tcp_header->source) == 139 || ntohs(attr.tcp_header->source) == 445) && (attr.tcp_data_len > 34)) {
		struct layer7_state_bucket * bucket;
		char * tcp_payload;
		u_int8_t * cmd;

		if((bucket = get_layer7_state_bucket(connection)) == NULL) { warnx("default_layer7_read_module: get_layer7_state_bucket"); return NULL; }

		tcp_payload = ((char *) attr.tcp_header) + (attr.tcp_header->doff<<2);

		cmd = (u_int8_t *) (((char *) tcp_payload) + 8);
		if(*cmd == 0x73) {
			/* depend on OS */
			char * smb_content = (char *) ((char *) tcp_payload + 36);
			int i;

			for(i=0; i< (int) (attr.tcp_data_len-36-8); ++i) {
				static char ntlmssp[] = "\x4e\x54\x4c\x4d\x53\x53\x50\x00";
				if(memcmp(smb_content+i, ntlmssp, 8) == 0) {
					char * security_blob = smb_content + i;
					u_int32_t * msg_type = (u_int32_t *) (security_blob + 8);   //for windows 2000
					if(*msg_type == 0x00000002) {
						u_int32_t * server_challenge = (u_int32_t *) (security_blob + 24);
						memcpy(bucket->server_challenge, server_challenge, 8);

						{
							char * cc = bucket->server_challenge;
							warnx("[%s] challenge = %x %x %x %x %x %x %x %x", bucket->key, cc[0]&0xff,cc[1]&0xff,cc[2]&0xff,cc[3]&0xff,cc[4]&0xff,cc[5]&0xff,cc[6]&0xff,cc[7]&0xff);
						}
					}
					break;
				}
			}
		}
	}

	return NULL;
}

/*
 * Write-back hook: performs no processing and always returns NULL.
 */
void * default_layer7_writeback_module(struct connection * connection, struct packet_attribute attr)
{
	(void) connection;	/* unused */
	(void) attr;		/* unused */

	return NULL;
}

/*
 * Return the layer-7 state bucket for `connection`, creating it on first use.
 *
 * Buckets live in the global hash map under the key
 * "l7_<session->sip>:<sport>_<session->dip>:<dport>". A newly created bucket is
 * zero-filled and has its key copied into bucket->key.
 *
 * Returns NULL after emitting a warning in any of these cases:
 *   - the key does not fit in STATE_BUCKET_KEY_LENGTH bytes;
 *   - allocation fails;
 *   - hash_map_put() reports -1. In this case the new bucket is freed rather than leaked.
 */
struct layer7_state_bucket * get_layer7_state_bucket(struct connection * connection)
{
	char key[STATE_BUCKET_KEY_LENGTH];
	struct layer7_state_bucket * bucket;
	int key_len;

	/* snprintf: sip/dip are strings whose length is not checked here, so an
	 * unbounded sprintf could overflow `key`.  A truncated key could collide
	 * with another connection's, so treat truncation as an error. */
	key_len = snprintf(key, sizeof(key), "l7_%s:%d_%s:%d", connection->session->sip, connection->sport, connection->session->dip, connection->dport);
	if(key_len < 0 || (size_t) key_len >= sizeof(key)) {
		warnx("get_layer7_state_bucket: key too long");
		return NULL;
	}

	if((bucket = hash_map_get(key)) != NULL) return bucket;

	if((bucket = malloc(sizeof(*bucket))) == NULL) { warn("get_layer7_state_bucket: malloc"); return NULL; }
	memset(bucket, 0, sizeof(*bucket));
	/* key_len + 1 <= sizeof(key) == sizeof(bucket->key): copies the terminator too. */
	memcpy(bucket->key, key, (size_t) key_len + 1);

	/* Register with the bucket's own copy of the key, which lives as long as the
	 * bucket.  NOTE(review): whether hash_map_put copies the key or keeps the
	 * pointer is not visible here; this is safe in either case, unlike the
	 * stack buffer `key`. */
	if(hash_map_put(bucket->key, bucket) == -1) {
		warnx("get_layer7_state_bucket: hash_map_put");
		free(bucket);
		return NULL;
	}

	return bucket;
}

/*
 * challenge-response
 *
 * Builds a 24-byte response in dst from an 8-byte server challenge and an
 * 8-byte client challenge:
 *   1. keys  = MD4 of the UTF-16 form of the password (here the empty
 *              string), zero-padded to 21 bytes;
 *   2. block = first 8 bytes of MD5(server_challenge || client_challenge);
 *   3. dst[8k .. 8k+7] = DES(keys[7k .. 7k+6], block) for k = 0, 1, 2.
 * NOTE(review): this matches the NTLM2 session response construction for an
 * empty password -- confirm against the target protocol.
 */
void compute_response(char * server_challenge, char * client_challenge, char * dst)
{
	char password[] = "";
	char password_utf16[BUFSIZ];
	char nonce[16];			/* server_challenge || client_challenge */
	char nonce_digest[16];
	char block[8];
	unsigned char des_keys[21] = { 0 };	/* 16-byte MD4 hash + 5 zero bytes */
	int k;

	/* Password length is sizeof(password) - 1 == 0; its UTF-16 form is twice that. */
	unicode(password, (int) (sizeof(password) - 1), password_utf16);
	MD4((unsigned char *) password_utf16, 2 * (sizeof(password) - 1), des_keys);

	memcpy(nonce, server_challenge, 8);
	memcpy(nonce + 8, client_challenge, 8);
	MD5((unsigned char *) nonce, sizeof(nonce), (unsigned char *) nonce_digest);
	memcpy(block, nonce_digest, sizeof(block));

	/* Three DES encryptions, each keyed by the next 7 bytes of des_keys. */
	for(k = 0; k < 3; ++k)
		DES(des_keys + 7 * k, block, dst + 8 * k);
}

/*
 * Widen `len` single-byte characters from src into dst, writing each source
 * byte followed by a zero byte (2 * len bytes total, no terminator added).
 * Nothing is written when len <= 0.
 */
void unicode(char * src, int len, char * dst)
{
	int remaining;

	for(remaining = len; remaining > 0; --remaining) {
		*dst++ = *src++;
		*dst++ = '\0';
	}
}

/*
 * Single-block DES-ECB encryption with a 56-bit key.
 *
 * K   7-byte raw key (56 key bits)
 * D   8-byte plaintext
 * dst 8-byte ciphertext output
 *
 * The 56 key bits are spread over 8 bytes, 7 bits per byte in the high
 * positions. The low bit of each byte is then set to odd parity before
 * the key schedule is built.
 */
void DES(IN unsigned char * K, IN char * D, OUT char * dst)
{
	DES_cblock des_key;
	DES_cblock input;
	DES_key_schedule schedule;
	int i;

	/* Byte i takes the low i bits of K[i-1] and the high 8-i bits of K[i]. */
	des_key[0] = K[0];
	for(i = 1; i < 7; ++i)
		des_key[i] = (unsigned char) ((K[i - 1] << (8 - i)) | (K[i] >> i));
	des_key[7] = (unsigned char) (K[6] << 1);

	DES_set_odd_parity(&des_key);
	DES_set_key_unchecked(&des_key, &schedule);

	memcpy(&input, D, 8);
	DES_ecb_encrypt((DES_cblock *) &input, (DES_cblock *) dst, &schedule, DES_ENCRYPT);
}

/*
 * Write the 16-byte MD5 digest of data[0 .. len) into dst.
 * dst is cleared before hashing.
 */
void C_MD5(IN char * data, int len, OUT char * dst)
{
	unsigned char * digest = (unsigned char *) dst;
	MD5_CTX md5;

	memset(digest, 0, 16);

	MD5_Init(&md5);
	MD5_Update(&md5, data, len);
	MD5_Final(digest, &md5);
}

/*
 * Write the 16-byte MD4 digest of data[0 .. len) into dst.
 * dst is cleared before hashing.
 */
void C_MD4(IN char * data, int len, OUT char * dst)
{
	unsigned char * digest = (unsigned char *) dst;
	MD4_CTX md4;

	memset(digest, 0, 16);

	MD4_Init(&md4);
	MD4_Update(&md4, data, len);
	MD4_Final(digest, &md4);
}

/*
 * Print data[0 .. len) to stdout as contiguous two-digit lowercase hex,
 * followed by a newline.
 */
void hex_dump(IN char * data, int len)
{
	const unsigned char * bytes = (const unsigned char *) data;
	int k;

	for(k = 0; k < len; k++)
		printf("%02x", bytes[k]);
	putchar('\n');
}

