#include "ScriptEngine.h"
#include "CleoInstance.h"
#include "DebugLog.h"
#include "GameVersionManager.h"
#include "FileEnumerator.h"
#include <direct.h>
#include <fstream>
#include <memory>
#include "crc32.h"

namespace CLEO 
{
	// Game-engine functions operating on a script thread (passed as the first
	// argument). All of them are resolved for the running game version in
	// ScriptEngine::inject() via GameVersionManager::TranslateMemoryAddress.
	void THISCALL (* AddScriptToQueue)(CScriptThread *, CScriptThread **queue);
	void THISCALL(* RemoveScriptFromQueue)(CScriptThread *, CScriptThread **queue);
	void THISCALL(* StopScript)(CScriptThread *);
	char THISCALL(* ScriptOpcodeHandler00)(CScriptThread *, WORD opcode);
	void THISCALL (* GetScriptParams)(CScriptThread *, int count);
	void THISCALL (* TransmitScriptParams)(CScriptThread *, CScriptThread *);
	void THISCALL (* SetScriptParams)(CScriptThread *, int count);
	void THISCALL (* SetScriptCondResult)(CScriptThread *, bool);
	SCRIPT_VAR * THISCALL (* GetScriptParamPointer1)(CScriptThread *);
	void THISCALL (* GetScriptStringParam)(CScriptThread *, char* buf, BYTE len);
	SCRIPT_VAR * THISCALL (* GetScriptParamPointer2)(CScriptThread *, int __unused__);

	// Game-side SCM lifecycle functions (also resolved in inject()); the
	// OnInitScm*/OnLoadScmData/OnSaveScmData hooks below forward to them.
	void CDECL(* InitScm)(void);
	void CDECL(* SaveScmData)(void);
	void CDECL(* LoadScmData)(void);
	
	// Pointers into game memory, resolved in inject().
	DWORD* GameTimer;
	// C linkage — NOTE(review): presumably so these can be referenced by their
	// plain names from outside C++ (e.g. assembly); confirm before changing.
	extern "C" {
		SCRIPT_VAR *opcodeParams;
		SCRIPT_VAR *missionLocals;
		CScriptThread *staticThreads;
	}
	BYTE *scmBlock;
	BYTE *MissionLoaded;
	BYTE *missionBlock;
	BOOL *onMissionFlag;

	// Heads of the game's inactive and active script-thread queues.
	CScriptThread **inactiveThreadQueue, **activeThreadQueue;

	// called each time user starts new game or loads a safe
	void CDECL OnInitScm1(void)
	{
		TRACE("Initializing SCM #1");
		GetThisInstance().textManager.ClearDynamicFxts();
		GetThisInstance().opcodeSystem.FinalizeScriptObjects();
		GetThisInstance().scriptEngine.RemoveAllCustomScripts();
		GetThisInstance().soundSystem.UnloadAllStreams();
		InitScm();
		GetThisInstance().scriptEngine.LoadCustomScripts();
	}

	// called only 1 time when user first time starts new game or loads a safe
	void CDECL OnInitScm2(void)
	{
		TRACE("Initializing SCM #2");
		GetThisInstance().textManager.ClearDynamicFxts();
		GetThisInstance().opcodeSystem.FinalizeScriptObjects();
		GetThisInstance().scriptEngine.RemoveAllCustomScripts();
		GetThisInstance().soundSystem.UnloadAllStreams();
		InitScm();
		GetThisInstance().scriptEngine.LoadCustomScripts();
	}

	// called each time only when user loads a safe
	void CDECL OnInitScm3(void)
	{
		TRACE("Initializing SCM #3");
		GetThisInstance().textManager.ClearDynamicFxts();
		GetThisInstance().opcodeSystem.FinalizeScriptObjects();
		GetThisInstance().scriptEngine.RemoveAllCustomScripts();
		GetThisInstance().soundSystem.UnloadAllStreams();
		InitScm();
		GetThisInstance().scriptEngine.LoadCustomScripts(true);
	}
	
	// Replacement for opcode 004E (terminate this script), called from
	// opcode_004E_hook with the executing thread.
	//  - Custom non-mission scripts: the opcode is rejected (logged, no effect).
	//  - Custom missions: unregistered, shut down and destroyed. The object must
	//    not be touched after `delete`, and being heap-allocated it must not be
	//    placed into the inactiveThreadQueue (NOTE(review): presumably the game's
	//    pool of reusable thread slots).
	//  - Game threads: moved from the active to the inactive queue and stopped.
	extern "C" void STDCALL opcode_004E(CScriptThread *thread)
	{
		if (thread->IsCustom)
		{
			auto cs = reinterpret_cast<CCustomScript *>(thread);
			if (!cs->missionFlag)
			{
				TRACE("[004E] Incorrect usage of opcode.");
				return;
			}
			// RemoveCustomScript already unlinks the mission from activeThreadQueue.
			GetThisInstance().scriptEngine.RemoveCustomScript(cs);
			// Shut the script down while the object is still alive, then free it.
			StopScript(thread);
			delete cs;
			return;
		}
		RemoveScriptFromQueue(thread, activeThreadQueue);
		AddScriptToQueue(thread, inactiveThreadQueue);
		StopScript(thread);
	}
	
	// Stub installed at MA_OPCODE_004E by ScriptEngine::inject(). At that site
	// the executing script thread is expected in %esi; it is pushed as the single
	// stdcall argument of opcode_004E (hence the @4 decoration). After the call
	// the stub unwinds the frame of the game function it was injected into and
	// returns 1 (in %al) to that function's caller.
	// NOTE(review): the register pops, the %fs:0 restore (Win32 SEH chain head),
	// the 32-byte stack adjustment and `ret $4` must match the injected function's
	// own epilogue for every supported game version — verify against disassembly.
	extern "C" void opcode_004E_hook(void) __attribute__((weak));
	asm volatile ("_opcode_004E_hook:\n"
		"	pushl %esi\n"
		"	call _opcode_004E@4\n"
		"	popl %edi\n"				// function epilogue
		"	movb $1,%al\n"
		"	popl %esi\n"
		"	movl 20(%esp),%ecx\n"
		"	movl %ecx,%fs:0\n"
		"	addl $32, %esp\n"
		"	ret $4\n"
		);

	// Injected only on GV_EU11, at a hard-coded address (see ScriptEngine::inject).
	// Writes default gang weapon loadouts into the game's gang-weapon table
	// (EU 1.01 address) and then resets all CLEO-side scripting state.
	void CDECL OnNewGame(void)
	{
		struct REVERSED CGangWeapons {
			BYTE _f0;
			BYTE _f1; // -
			DWORD weapon1;
			DWORD weapon2;
			DWORD weapon3;
		};

		// Default loadout per gang slot; weapon3 is always cleared.
		// NOTE(review): gang slot 3 is deliberately not written here — confirm why.
		struct GangLoadout { unsigned gang; DWORD first; DWORD second; };
		static const GangLoadout kLoadouts[] = {
			{ 0, 22, 28 },
			{ 1, 22,  0 },
			{ 2, 22,  0 },
			{ 4, 24, 28 },
			{ 5, 24,  0 },
			{ 6, 22, 30 },
			{ 7, 22, 28 },
		};

		CGangWeapons * const table = reinterpret_cast<CGangWeapons *>(0xC0B870);	// 1.01 eu specific
		TRACE("New game started");
		for (const GangLoadout& entry : kLoadouts)
		{
			CGangWeapons& slot = table[entry.gang];
			slot.weapon1 = entry.first;
			slot.weapon2 = entry.second;
			slot.weapon3 = 0;
		}

		GetThisInstance().textManager.ClearDynamicFxts();
		GetThisInstance().opcodeSystem.FinalizeScriptObjects();
		GetThisInstance().scriptEngine.RemoveAllCustomScripts();
		GetThisInstance().soundSystem.UnloadAllStreams();
		GetThisInstance().scriptEngine.LoadCustomScripts();
	}

	// Replaces the game's call at MA_CALL_LOAD_SCM_DATA: logs, then forwards to
	// the game's own loader.
	void CDECL OnLoadScmData(void)
	{
		TRACE(__PRETTY_FUNCTION__);
		LoadScmData();
	}

	// Replaces the game's call at MA_CALL_SAVE_SCM_DATA. CLEO's own state is
	// saved first; custom scripts are then taken off the active queue while the
	// game saves its SCM data, and put back afterwards.
	void CDECL OnSaveScmData(void)
	{
		TRACE(__PRETTY_FUNCTION__);
		auto& engine = GetThisInstance().scriptEngine;
		engine.SaveState();
		engine.UnregisterAllScripts();
		SaveScmData();
		engine.ReregisterAllScripts();
	}

	// Header of a CLEO save file (cs<N>.sav). It is written and read as raw
	// bytes, so member order and types are part of the on-disk format.
	struct CleoSafeHeader
	{
		static const unsigned sign;		// value `signature` must hold
		unsigned signature;				// file-format marker, compared with sign
		unsigned n_saved_threads;		// number of ThreadSavingInfo records
		unsigned n_stopped_threads;		// number of stopped-script hashes
	};

	// Reads as "SV41" when the bytes are viewed in little-endian order.
	const unsigned CleoSafeHeader::sign(0x31345653);

	// Per-script state persisted in the CLEO save file. Records are written and
	// read as raw bytes (WriteBinary/ReadBinary), so member order and types are
	// part of the file format and must not change.
	struct ThreadSavingInfo
	{
		unsigned long hash;			// script checksum; matched against CCustomScript::checksum on load
		SCRIPT_VAR tls[32];
		unsigned timers[2];
		bool condResult;
		unsigned sleepTime;			// time left until wakeTime, relative to *GameTimer
		unsigned short logicalOp;
		bool notFlag;
		ptrdiff_t ip_diff;			// ip stored as an offset from baseIp
		char threadName[8];

		// Captures the resumable state of `cs`.
		ThreadSavingInfo(CCustomScript *cs) :
			hash(cs->checksum), condResult(cs->condResult),
			logicalOp(cs->logicalOp), notFlag(cs->notFlag), ip_diff(cs->ip - cs->baseIp)
		{
			// Only a script whose wake-up time is still in the future has sleep
			// time left; otherwise store 0 (never an unsigned wrap-around).
			sleepTime = cs->wakeTime > *GameTimer ? cs->wakeTime - *GameTimer : 0;
			std::copy(cs->tls, cs->tls + 32, tls);
			std::copy(cs->timers, cs->timers + 2, timers);
			std::copy(cs->threadName, cs->threadName + 8, threadName);
		}

		// Restores the captured state into `cs`. ip and wakeTime are rebased onto
		// cs->baseIp and the current *GameTimer, and saving is enabled for it.
		void Apply(CCustomScript *cs)
		{
			cs->checksum = hash;
			std::copy(tls, tls + 32, cs->tls);
			std::copy(timers, timers + 2, cs->timers);
			cs->condResult = condResult;
			cs->wakeTime = *GameTimer + sleepTime;
			cs->logicalOp = logicalOp;
			cs->notFlag = notFlag;
			cs->ip = cs->baseIp + ip_diff;
			std::copy(threadName, threadName + 8, cs->threadName);
			cs->saving_enabled = true;
		}

		// Leaves members uninitialised; used as the target of a raw binary read.
		ThreadSavingInfo() { }
	};

	// Storage for the 0x400 CLEO global variables. SaveState() writes it to the
	// CLEO save file, and LoadCustomScripts(true) restores it from that file.
	SCRIPT_VAR ScriptEngine::CleoVariables[0x400];

	// Raw binary (de)serialisation helpers. They copy the object representation
	// of values to/from a stream with no conversion or framing, and are only
	// meaningful for trivially-copyable types. Stream errors are reported through
	// the stream's own state/exception mask.

	// Reads `size` consecutive objects of type T into `buf`.
	template<typename T>
	inline void ReadBinary(std::istream& s, T *buf, size_t size)
	{
		char *dst = reinterpret_cast<char *>(buf);
		s.read(dst, sizeof(T) * size);
	}

	// Reads exactly one object of type T into `buf`.
	template<typename T>
	inline void ReadBinary(std::istream& s, T& buf)
	{
		ReadBinary(s, &buf, 1);
	}

	// Writes `size` consecutive objects of type T starting at `data`.
	template<typename T>
	inline void WriteBinary(std::ostream& s, const T *data, size_t size)
	{
		const char *src = reinterpret_cast<const char *>(data);
		s.write(src, sizeof(T) * size);
	}

	// Writes exactly one object of type T.
	template<typename T>
	inline void WriteBinary(std::ostream& s, const T& data)
	{
		WriteBinary(s, &data, 1);
	}
	
	// Resolves every game address the script engine uses for the running game
	// version, then installs the SCM lifecycle hooks and the opcode 004E hook.
	// On GV_EU11 only the MA_CALL_INIT_SCM3 call is replaced, and OnNewGame is
	// injected at a hard-coded EU 1.01 address. Other versions replace all three
	// InitScm call sites.
	void ScriptEngine::inject(CodeInjector& inj)
	{
		TRACE("Injecting ScriptEngine...");
		GameVersionManager& gvm = GetThisInstance().versionManager;
		// Game functions.
		AddScriptToQueue = gvm.TranslateMemoryAddress(MA_ADD_SCRIPT_TO_QUEUE_FUNCTION);
		RemoveScriptFromQueue = gvm.TranslateMemoryAddress(MA_REMOVE_SCRIPT_FROM_QUEUE_FUNCTION);
		StopScript = gvm.TranslateMemoryAddress(MA_STOP_SCRIPT_FUNCTION);
		ScriptOpcodeHandler00 = gvm.TranslateMemoryAddress(MA_SCRIPT_OPCODE_HANDLER0_FUNCTION);
		GetScriptParams = gvm.TranslateMemoryAddress(MA_GET_SCRIPT_PARAMS_FUNCTION);
		TransmitScriptParams = gvm.TranslateMemoryAddress(MA_TRANSMIT_SCRIPT_PARAMS_FUNCTION);
		SetScriptParams = gvm.TranslateMemoryAddress(MA_SET_SCRIPT_PARAMS_FUNCTION);
		SetScriptCondResult = gvm.TranslateMemoryAddress(MA_SET_SCRIPT_COND_RESULT_FUNCTION);
		GetScriptParamPointer1 = gvm.TranslateMemoryAddress(MA_GET_SCRIPT_PARAM_POINTER1_FUNCTION);
		GetScriptStringParam = gvm.TranslateMemoryAddress(MA_GET_SCRIPT_STRING_PARAM_FUNCTION);
		GetScriptParamPointer2 = gvm.TranslateMemoryAddress(MA_GET_SCRIPT_PARAM_POINTER2_FUNCTION);

		InitScm = gvm.TranslateMemoryAddress(MA_INIT_SCM_FUNCTION);
		SaveScmData = gvm.TranslateMemoryAddress(MA_SAVE_SCM_DATA_FUNCTION);
		LoadScmData = gvm.TranslateMemoryAddress(MA_LOAD_SCM_DATA_FUNCTION);

		// Game data.
		GameTimer = gvm.TranslateMemoryAddress(MA_GAME_TIMER);
		opcodeParams = gvm.TranslateMemoryAddress(MA_OPCODE_PARAMS);
		// NOTE(review): the next two assignments look crossed. missionLocals
		// (SCRIPT_VAR*) receives MA_SCM_BLOCK, and scmBlock (BYTE*) receives
		// MA_MISSION_LOCALS. They are left unchanged because the address table in
		// GameVersionManager is not visible here. Confirm which entry holds which
		// address before swapping them.
		missionLocals = gvm.TranslateMemoryAddress(MA_SCM_BLOCK);
		scmBlock = gvm.TranslateMemoryAddress(MA_MISSION_LOCALS);
		MissionLoaded = gvm.TranslateMemoryAddress(MA_MISSION_LOADED);
		missionBlock = gvm.TranslateMemoryAddress(MA_MISSION_BLOCK);
		onMissionFlag = gvm.TranslateMemoryAddress(MA_ON_MISSION_FLAG);

		inactiveThreadQueue =  gvm.TranslateMemoryAddress(MA_INACTIVE_THREAD_QUEUE);
		activeThreadQueue =  gvm.TranslateMemoryAddress(MA_ACTIVE_THREAD_QUEUE);
		staticThreads =  gvm.TranslateMemoryAddress(MA_STATIC_THREADS);

		// Hooks.
		if (gvm.GetGameVersion() == GV_EU11)
		{
			inj.ReplaceFunction(OnInitScm3,	gvm.TranslateMemoryAddress(MA_CALL_INIT_SCM3));
			inj.InjectFunction(OnNewGame, 0x5DEEA0);	// GV_EU11 specific
		}
		else
		{
			inj.ReplaceFunction(OnInitScm1,	gvm.TranslateMemoryAddress(MA_CALL_INIT_SCM1));
			inj.ReplaceFunction(OnInitScm2,	gvm.TranslateMemoryAddress(MA_CALL_INIT_SCM2));
			inj.ReplaceFunction(OnInitScm3,	gvm.TranslateMemoryAddress(MA_CALL_INIT_SCM3));
		}
		inj.ReplaceFunction(OnLoadScmData, gvm.TranslateMemoryAddress(MA_CALL_LOAD_SCM_DATA));
		inj.ReplaceFunction(OnSaveScmData, gvm.TranslateMemoryAddress(MA_CALL_SAVE_SCM_DATA));
		inj.InjectFunction(&opcode_004E_hook, gvm.TranslateMemoryAddress(MA_OPCODE_004E));
	}

	void ScriptEngine::LoadCustomScripts(bool load_mode)
	{
		char safe_name[MAX_PATH];
		sprintf(safe_name, "./cleo/cleo_saves/cs%d.sav", menuManager->SaveNumber);
		CleoSafeHeader safe_header;
		ThreadSavingInfo *safe_info = nullptr;
		unsigned long *stopped_info = nullptr;
		std::unique_ptr<ThreadSavingInfo[]> safe_info_utilizer;
		std::unique_ptr<unsigned long[]> stopped_info_utilizer;
		safe_header.n_saved_threads = safe_header.n_stopped_threads = 0;
		if (load_mode)
		{
			// load cleo saving file
			try
			{
				TRACE("Loading cleo safe %s", safe_name);
				std::ifstream ss(safe_name, std::ios::binary);
				ss.exceptions(std::ios::eofbit | std::ios::badbit  
					| std::ios::failbit);
				ReadBinary(ss, safe_header);
				if (safe_header.signature != CleoSafeHeader::sign) 
					throw std::runtime_error("Invalid file format");
				safe_info = new ThreadSavingInfo[safe_header.n_saved_threads];
				safe_info_utilizer.reset(safe_info);
				stopped_info = new unsigned long[safe_header.n_stopped_threads];
				stopped_info_utilizer.reset(stopped_info);
				ReadBinary(ss, CleoVariables, 0x400);
				ReadBinary(ss, stopped_info, safe_header.n_stopped_threads);
				ReadBinary(ss, safe_info, safe_header.n_saved_threads);
				for (size_t i = 0; i < safe_header.n_stopped_threads; ++i)
					stoppedThreadHashes.insert(stopped_info[i]);
				TRACE("Finished. Loaded %u cleo variables, %u saved threads info, %u stopped threads info", 
					0x400, safe_header.n_saved_threads, safe_header.n_stopped_threads);
			}
			catch (std::exception& ex)
			{
				TRACE("Loading of cleo safe %s failed: %s", safe_name, ex.what());
				safe_header.n_saved_threads = safe_header.n_stopped_threads = 0;
			}
		}
		char cwd[MAX_PATH];
		_getcwd(cwd, sizeof(cwd));
		chdir(cleo_dir);
		enumerate_files(cs_mask, [this,stopped_info, safe_info, &safe_header]
			(const char *filename)
			{
				auto cs = new CCustomScript(filename);
				if (!cs->OK) 
				{
					TRACE("Loading of custom script %s failed", filename);
					delete cs;
					return;
				}
				// check whether the script is in stop-list
				if (stopped_info)
				{
					for (size_t i = 0; i < safe_header.n_stopped_threads; ++i)
						if (stopped_info[i] == cs->checksum) 
						{
							TRACE("Custom script %s found in the stop-list", filename);
							stoppedThreadHashes.insert(stopped_info[i]);
							delete cs;
							return;
						}
				}
				// check whether the script is in safe-list
				if (safe_info)
				{
					for (size_t i = 0; i < safe_header.n_saved_threads; ++i)
						if (safe_info[i].hash == cs->checksum) 
						{
							TRACE("Custom script %s found in the safe-list", filename);
							safe_info[i].Apply(cs);
							break;
						}
				}
				AddCustomScript(cs);
			});
		chdir(cwd);
	}

	void ScriptEngine::SaveState()
	{
		try
		{
			std::list<CCustomScript *> savedThreads;
			std::for_each(customScripts.begin(), customScripts.end(),
				[this,&savedThreads](CCustomScript *cs)
				{
					if (cs->saving_enabled) savedThreads.push_back(cs);
				});
			CleoSafeHeader header = { CleoSafeHeader::sign, savedThreads.size(), 
				stoppedThreadHashes.size() };
			char safe_name[MAX_PATH];
			sprintf(safe_name, "./cleo/cleo_saves/cs%d.sav", menuManager->SaveNumber);
			TRACE("Saving script engine state to the file %s", safe_name);
			std::ofstream ss(safe_name, std::ios::binary);
			ss.exceptions(std::ios::failbit | std::ios::badbit);
			WriteBinary(ss, header);
			WriteBinary(ss, CleoVariables, 0x400);
			std::for_each(savedThreads.begin(), savedThreads.end(),
				[&savedThreads, &ss](CCustomScript *cs)
				{
					ThreadSavingInfo savingInfo(cs);
					WriteBinary(ss, savingInfo);
				});
			std::for_each(stoppedThreadHashes.begin(), stoppedThreadHashes.end(),
				[&ss](unsigned long hash)
				{
					WriteBinary(ss, hash);
				});
			TRACE("Done. Saved %u cleo variables, %u saved threads, %u stopped threads", 
				0x400, header.n_saved_threads, header.n_stopped_threads);
		}
		catch (std::exception& ex)
		{
			TRACE("Saving failed. %s", ex.what());
		}
	}

	// Returns the registered custom mission or custom script whose threadName
	// matches `name` case-insensitively, or nullptr if none does. The custom
	// mission is checked first.
	CCustomScript *ScriptEngine::FindThreadByName(const char *name)
	{
		// customMission is nullptr whenever no custom mission is registered.
		if (customMission && _stricmp(name, customMission->threadName) == 0)
			return customMission;
		for (auto it = customScripts.begin(); it != customScripts.end(); ++it)
		{
			CCustomScript *cs = *it;
			if (_stricmp(name, cs->threadName) == 0)
				return cs;
		}
		return nullptr;
	}

	// Registers `cs` with the engine and links it into the active-thread queue.
	// A mission script becomes the (single) customMission; any other script is
	// appended to customScripts.
	void ScriptEngine::AddCustomScript(CCustomScript *cs)
	{
		if (!cs->missionFlag)
		{
			TRACE("Registering custom script named %s", cs->threadName);
			customScripts.push_back(cs);
		}
		else
		{
			TRACE("Registering custom mission named %s", cs->threadName);
			customMission = cs;
		}
		AddScriptToQueue(cs, activeThreadQueue);
		cs->isActive = true;
	}

	// Unregisters `cs` and unlinks it from the active-thread queue, without
	// deleting it. A regular script with saving enabled also has its checksum
	// recorded in stoppedThreadHashes, so it is persisted as stopped.
	void ScriptEngine::RemoveCustomScript(CCustomScript *cs)
	{
		if (cs != customMission)
		{
			if (cs->saving_enabled)
			{
				stoppedThreadHashes.insert(cs->checksum);
				TRACE("Stopping custom script named %s", cs->threadName);
			}
			else
			{
				TRACE("Unregistering custom script named %s", cs->threadName);
			}
			customScripts.remove(cs);
			RemoveScriptFromQueue(cs, activeThreadQueue);
			cs->isActive = false;
			return;
		}

		TRACE("Unregistering custom mission named %s", cs->threadName);
		RemoveScriptFromQueue(cs, activeThreadQueue);
		cs->isActive = false;
		customMission = nullptr;
	}

	void ScriptEngine::RemoveAllCustomScripts(void)
	{
		stoppedThreadHashes.clear();
		std::for_each(customScripts.begin(), customScripts.end(),
			[this](CCustomScript *cs)
			{
				TRACE("Unregistering custom script named %s", cs->threadName);
				RemoveScriptFromQueue(cs, activeThreadQueue);
				cs->isActive = false;
				delete cs;
			});
		customScripts.clear();
		if (customMission)
		{
			TRACE("Unregistering custom mission named %s", customMission->threadName);
			RemoveScriptFromQueue(customMission, activeThreadQueue);
			customMission->isActive = false;
			delete customMission;
			customMission = nullptr;
		}
	}

	void ScriptEngine::UnregisterAllScripts()
	{
		TRACE("Unregistering all custom scripts");
		std::for_each(customScripts.begin(), customScripts.end(),
			[this](CCustomScript *cs)
			{
				RemoveScriptFromQueue(cs, activeThreadQueue);
				cs->isActive = false;
			});
	}

	void ScriptEngine::ReregisterAllScripts()
	{
		TRACE("Reregistering all custom scripts");
		std::for_each(customScripts.begin(), customScripts.end(),
			[this](CCustomScript *cs)
			{
				AddScriptToQueue(cs, activeThreadQueue);
				cs->isActive = true;
			});
	}

	// A new engine has no custom mission registered.
	ScriptEngine::ScriptEngine()
		: customMission(nullptr)
	{
	}

	// Unlinks and deletes every custom script and mission the engine still owns.
	ScriptEngine::~ScriptEngine()
	{
		this->RemoveAllCustomScripts();
	}

	// Loads a custom script, or a custom mission when is_mission is true, from
	// `filename`. Mission code is read into the game's mission block; other
	// scripts go into a heap buffer released by the destructor. On any error,
	// Error() is reported and OK stays false, so callers must check OK.
	CCustomScript::CCustomScript(const char *filename, bool is_mission)
		: CScriptThread::CScriptThread(), saving_enabled(false), OK(false),
		lastFoundActor(0), lastFoundVehicle(0), lastFoundObject(0)
	{
		IsCustom = 1; missionFlag = MissionCleanUpFlag = is_mission;
		TRACE("Loading custom script %s...", filename);
		try
		{
			std::ifstream is(filename, std::ios::binary);
			// Throws immediately if the file could not be opened.
			is.exceptions(std::ios::badbit | std::ios::failbit);
			is.seekg(0, std::ios::end);
			std::size_t length = is.tellg();
			is.seekg(0, std::ios::beg);
			if (is_mission)
			{
				if (*MissionLoaded)
					throw std::logic_error("Starting of custom mission when other mission loaded");
				// NOTE(review): `length` is not checked against the size of missionBlock.
				baseIp = ip = missionBlock;
			}
			else
				baseIp = ip = new BYTE[length];
			is.read(reinterpret_cast<char *>(baseIp), length);
			// Claim the mission slot only after the code was actually loaded, so a
			// failed read does not leave MissionLoaded set.
			if (is_mission)
				*MissionLoaded = 1;
			// threadName is a fixed-size field. strncpy never reads past the end of
			// a shorter filename (memcpy of the full field size did) and zero-fills
			// the remainder; the last byte is forced to a terminator.
			strncpy(threadName, filename, sizeof(threadName));
			threadName[sizeof(threadName) - 1] = '\0';
			checksum = crc32(baseIp, length);
			OK = true;
		}
		catch (std::exception& ex)
		{
			std::string message = std::string("Error during loading of custom script ") +
				filename + " occured.\n" 
				"Error message: " + ex.what();
			Error(message.c_str());
		}
		catch(...)
		{
			std::string message = std::string("Unknown error during loading of custom script ") +
				filename + " occured.";
			Error(message.c_str());
		}
	}

	// Releases the script's code buffer. Mission code lives in the game's mission
	// block rather than a buffer allocated by this object, so it is not freed.
	CCustomScript::~CCustomScript()
	{
		const bool ownsCode = !this->missionFlag;
		if (ownsCode)
			delete[] baseIp;	// no-op when baseIp is null
	}
}
