/* libw32dl.c -- dlopen and friends for mingw32 */

/* Copyright (c) 2007 Ian Piumarta
 * All rights reserved.
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the 'Software'),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, provided that the above copyright notice(s) and this
 * permission notice appear in all copies of the Software.  Including the
 * above copyright notice(s) and this permission notice in supporting
 * documentation would be appreciated but is not required.
 *
 * THE SOFTWARE IS PROVIDED 'AS IS'.  USE ENTIRELY AT YOUR OWN RISK.
 *
 * Last edited: 2007-09-10 15:27:01 by piumarta on cygwin.piumarta.com
 */


#include "w32dlfcn.h"

#include <stdlib.h>
#include <string.h>
#include <windows.h>
#include <process.h>
#include <tlhelp32.h>

/* Debug tracing.  Change "#if 0" to "#if 1" to print each dlopen/dlsym/
 * dlclose decision to stdout; otherwise dprintf() expands to nothing and its
 * arguments are not evaluated.  Uses the GNU "args..." named variadic macro
 * extension (and "##args" to swallow the comma when there are no args).
 */
#if 0
# include <stdio.h>
# define dprintf(fmt, args...)	fprintf(stdout, fmt, ##args)
#else
# define dprintf(fmt, args...)
#endif

/* One entry in the singly linked list of modules opened through dlopen().
 * NOTE(review): dlopen() passes the head of this list to an exported
 * "dlinit" function of each newly loaded module, which adopts it as its own
 * list -- presumably another copy of this file -- so the field layout must
 * stay identical across modules.  Do not reorder or add fields lightly.
 */
struct dll
{
  /* heap copy (strdup) of the name passed to dlopen(); the lookup key */
  char		*file;
  /* RTLD_* flags from dlopen(); RTLD_LOCAL entries are skipped by dlsym(0, ...) */
  int		 mode;
  /* Win32 module handle (HMODULE) */
  void		*handle;
  /* dlopen() calls not yet balanced by dlclose() */
  int		 refCount;
  /* next entry; 0 terminates the list */
  struct dll	*next;
};

/* Head of the list of opened modules (may be adopted from another module
   via dlinit(list)). */
static struct dll *dlls= 0;
/* Handle of the main program, GetModuleHandle(0); returned by dlopen(0). */
static HANDLE	   dlmain= 0;
/* Non-zero once dlinit() has run; every entry point initialises lazily. */
static int	   dlinitialised= 0;
/* Win32 GetLastError() code captured by dlopen()/dlsym(), formatted by dlerror(). */
static int	   dlerrno= 0;

/* Allocate an unlinked list entry recording that FILE was opened with MODE
 * and yielded HANDLE.  The entry owns a private copy of FILE and starts with
 * a reference count of 1.
 *
 * Returns the new entry, or 0 if memory could not be allocated (nothing is
 * leaked in that case).
 */
static struct dll *dllNew(const char *file, int mode, void *handle)
{
  struct dll *dll= malloc(sizeof *dll);
  if (!dll)
    return 0;
  if (!(dll->file= strdup(file)))
    {
      free(dll);
      return 0;
    }
  dll->mode	= mode;
  dll->handle	= handle;
  dll->refCount	= 1;
  dll->next	= 0;
  return dll;
}

/* Release a list entry together with the file name it owns.
 * Callers unlink the entry from the list before calling this.
 */
static void dllFree(struct dll *dll)
{
  char *name= dll->file;	/* grab the owned string before the node goes */
  free(dll);
  free(name);
}

/* Initialise the library state; called lazily by every entry point.
 *
 * dlopen() also looks up an exported "dlinit" in each newly loaded module
 * and calls it with its own list.  With a non-zero LIST this module simply
 * adopts that list.  Otherwise every module currently mapped into the
 * process (per a Toolhelp32 module snapshot) is registered through dlopen()
 * with RTLD_NOW | RTLD_GLOBAL, so that dlsym(0, ...) can search it.
 */
void dlinit(struct dll *list)
{
  HANDLE snap;
  MODULEENTRY32 entry;
  int more;

  /* Must be set before the dlopen() calls below, which would otherwise
     re-enter dlinit(). */
  dlinitialised= 1;
  dlmain= GetModuleHandle(0);

  if (list)
    {
      dlls= list;
      return;
    }

  snap= CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, getpid());
  if (snap == INVALID_HANDLE_VALUE)
    return;

  entry.dwSize= sizeof(entry);
  for (more= Module32First(snap, &entry);  more;  more= Module32Next(snap, &entry))
    dlopen(entry.szModule, RTLD_NOW | RTLD_GLOBAL);
  CloseHandle(snap);
}

/* Open the module FILE and return its handle, or 0 on failure (the Win32
 * error code is then saved for dlerror()).
 *
 * FILE == 0 returns the handle of the main program.  A FILE already in the
 * list (exact, case-sensitive name match) just has its reference count
 * incremented.  Otherwise an already-mapped module is found with
 * GetModuleHandle(), or a new one is mapped with LoadLibrary().  Only a
 * newly mapped module has its exported "dlinit" (given this module's list)
 * and "_init" hooks called.  On success a new entry recording MODE is
 * appended to the list.
 */
void *dlopen(const char *file, int mode)
{
  HANDLE handle= 0;
  struct dll *dll, **dllp;
  if (!dlinitialised) dlinit(0);
  if (!file)
    {
      dprintf("dlopen 0 -> %p\n", dlmain);
      return dlmain;
    }
  for (dllp= &dlls;  (dll= *dllp);  dllp= &dll->next)
    if (!strcmp(file, dll->file))
      {
	dll->refCount++;
	dprintf("dlopen %s -> %p ++ref\n", file, dll->handle);
	return dll->handle;
      }
  {
    /* Read the current error mode and add SEM_FAILCRITICALERRORS so that a
       missing DLL fails quietly instead of raising a dialog box. */
    unsigned int errorMode= SetErrorMode(SEM_FAILCRITICALERRORS);
    SetErrorMode(errorMode | SEM_FAILCRITICALERRORS);
    handle= GetModuleHandle(file);
    if (handle)
      {
	/* GetModuleHandle() takes no reference, but dlclose() will call
	   FreeLibrary().  Take one now so the two balance and dlclose()
	   cannot unload a module that someone else mapped.  dlclose() never
	   frees the main program, so it needs no extra reference. */
	if (handle != dlmain) LoadLibrary(file);
      }
    else
      {
	handle= LoadLibrary(file);
	if (handle)
	  {
	    void *init;
	    if ((init= GetProcAddress(handle, "dlinit")))	((void (*)(struct dll *))init)(dlls);
	    if ((init= GetProcAddress(handle, "_init")))	((void (*)(void))init)();
	  }
	else
	  /* Record failures only: a success must not clobber an error that
	     has not yet been retrieved with dlerror(). */
	  dlerrno= GetLastError();
      }
    SetErrorMode(errorMode);
  }
  if (handle)
    {
      /* The list is shared with the new module's "dlinit" hook, so the hooks
	 above may have appended entries: find the current tail again. */
      while (*dllp) dllp= &(*dllp)->next;
      *dllp= dllNew(file, mode, (void *)handle);	/* append */
    }
  dprintf("dlopen %s -> %p %d\n", file, handle, mode);
  return handle;
}

/* Look up the symbol NAME and return its address, or 0 if it is not found.
 * On failure the Win32 error code is saved for dlerror().
 *
 * A non-zero HANDLE (as returned by dlopen()) is searched directly.  A zero
 * HANDLE (RTLD_DEFAULT) searches the main program first, then each module
 * in list order (dlopen() appends), skipping modules opened RTLD_LOCAL.
 */
void *dlsym(void *__restrict__ handle, const char *__restrict__ name)
{
  void *addr= 0;
  struct dll *dll;
  if (!dlinitialised) dlinit(0);
  if (handle)
    {
      addr= GetProcAddress(handle, name);
      if (!addr)
	dlerrno= GetLastError();	/* make the failure visible to dlerror() */
      dprintf("dlsym %p \"%s\" -> %p\n", handle, name, addr);
      return addr;
    }
  /* handle is RTLD_DEFAULT */
  if ((addr= GetProcAddress(dlmain, name)))
    {
      dprintf("dlsym 0 \"%s\" -> %p\n", name, addr);
      return addr;
    }
  for (dll= dlls;  dll;  dll= dll->next)
    if ((!(dll->mode & RTLD_LOCAL)) && (addr= GetProcAddress(dll->handle, name)))
      {
	dprintf("dlsym <%s> %p %d \"%s\" -> %p\n", (dll->file ? dll->file : "[process]"), dll->handle, dll->mode, name, addr);
	return addr;
      }
  dlerrno= GetLastError();
  dprintf("dlsym %p \"%s\" -> FAIL\n", handle, name);
  return 0;
}

/* Describe the most recent dlopen()/dlsym() failure.
 *
 * As POSIX requires, this returns 0 if no error has been recorded since the
 * last call, and clears the recorded error once it has been reported.  The
 * returned text lives in a static buffer that the next call overwrites.
 */
char *dlerror(void)
{
  static char message[256];
  static char unknown[]= "unknown dynamic-linking error";
  unsigned long length;
  if (!dlerrno)
    return 0;
  length= FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
			NULL,
			dlerrno,
			MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
			&message[0], sizeof(message),
			NULL);
  dlerrno= 0;
  if (!length)
    return unknown;	/* don't hand back stale text from an earlier call */
  /* system messages end in "\r\n"; trim trailing line breaks and blanks */
  while (length > 0 && (message[length - 1] == '\n' || message[length - 1] == '\r' || message[length - 1] == ' '))
    message[--length]= '\0';
  return message;
}

/* Release one reference to HANDLE obtained from dlopen().
 *
 * When an entry's last reference is dropped, the module's exported "_fini"
 * hook (if any) is run, the entry is unlinked and freed, and the module is
 * FreeLibrary()'d.  A zero handle and the main program are never unloaded.
 * Returns 0, or -1 if FreeLibrary() reported failure.
 */
int dlclose(void *handle)
{
  struct dll *dll, **dllp;
  int result= 0;
  if (!dlinitialised) dlinit(0);
  if (!handle || handle == dlmain)
    return 0;
  for (dllp= &dlls;  (dll= *dllp);  dllp= &dll->next)
    if (dll->handle == handle)
      {
	/* Release only the first matching entry: the same module may be
	   listed under several names, each with its own reference count. */
	if (--dll->refCount <= 0)
	  {
	    void *fini= GetProcAddress(handle, "_fini");
	    if (fini) ((void (*)(void))fini)();
	    *dllp= dll->next;
	    if (!FreeLibrary(handle)) result= -1;
	    dprintf("free %s\n", dll->file);
	    dllFree(dll);
	  }
	break;
      }
  return result;
}

