#!/usr/bin/env python3

import os
import sys

# argv[1]: path of the compiler-list file to write; the generated C
# probe (amd64 only) is written next to it with a '.c' suffix.
fncompiler = sys.argv[1]
fncompilerc = f'{fncompiler}.c'

# The second '/'-separated component of that path is a '+'-separated
# tag list: tags[0] is the target (e.g. 'amd64'), the rest are
# instruction-set extensions, e.g. 'x/amd64+avx2' -> ['amd64','avx2'].
tags = fncompiler.split('/')[1].split('+')

if tags[0] == 'amd64':

  # Write fncompilerc: C source defining int supports(void), which returns 1
  # iff cpuid (and, for AVX-family tags, xgetbv) reports every feature that
  # tags requires (plus the amd64 baseline mmx/sse/sse2), and 0 otherwise.

  def extended(i):
    '''True if cpuid level i is an extended level (0x80000000 and up).'''
    return i >= 0x80000000

  def WANT(i,j):
    '''Name of the C macro holding the required bits of level i, register j.'''
    if extended(i):
      return f'WANT_EXT{i-0x80000000}_{j}'
    return f'WANT_{i}_{j}'

  # (i,j) is cpuid level i returning [eax,ebx,ecx,edx][j];
  # the k-th word of each string names bit k of that register.
  # e.g. (1,2) is cpuid level 1 returning ecx
  cpuidbits = {
    (1,3): 'fpu vme de pse tsc msr pae mce cmpxchg8b apic reserved sysentersysexit mtrr pge mca cmov pat pse36 psn clflush reserved ds acpi mmx fxsr sse sse2 selfsnoop htt tm reserved pbe',
    (1,2): 'sse3 pclmulqdq dtes64 monitor dscpl vmx smx eist tm2 ssse3 l1contextid debuginterface fma cmpxchg16b xtprupdatecontrol perfcapabilities reserved pcid dca sse41 sse42 x2apic movbe popcnt tscdeadline aesni xsave osxsave avx f16c rdrand notused',
    (7,1): 'fsgsbase tscadjust sgx bmi1 hle avx2 fdpexcptnonly smep bmi2 erms invpcid rtm rdtm fcsfdsdeprecation mpx rdta avx512f avx512dq rdseed adx smap avx512ifma reserved clflushopt clwb intelproctrace avx512pf avx512er avx512cd sha avx512bw avx512vl',
    (7,2): 'reserved avx512vbmi umip pku ospke waitpkg avx512vbmi2 cetss gfni vaes vpclmulqdq avx512vnni avx512bitalg reserved avx512vpopc',
    (0x80000001,2): 'lahfsahf64 cmplegacy svm extapicspace altmovcr8 lzcnt sse4a misalignsse prefetchw osvw ibs xop skinit wdt reserved lwp fma4 tce',
  }

  # want[i,j] = set of (bit,name) pairs that supports() requires to be set
  # in register j of cpuid level i
  want = {}

  def do(tag):
    '''Require the cpuid bit named tag; tag must name exactly one bit.'''
    found = 0
    for (i,j),names in cpuidbits.items():
      for k,name in enumerate(names.split()):
        if name == tag:
          found += 1
          want.setdefault((i,j),set()).add((k,name))
    # explicit raise rather than assert: under python -O an unknown or
    # ambiguous tag must not be silently dropped from the runtime check
    if found != 1:
      raise ValueError(f'cpuid tag {tag!r} names {found} bits, expected exactly 1')

  # amd64 baseline, always required (this block only runs for amd64)
  for tag in ('mmx','sse','sse2'):
    do(tag)

  for tag in tags[1:]:
    do(tag)

  # AVX-family code also needs the OS to have enabled the extended register
  # state: checked via osxsave (cpuid) plus XCR0 (xgetbv). AVX-512 needs
  # XCR0 bits 5-7 (opmask, upper halves of zmm0-15, zmm16-31) on top of
  # bits 1-2 (xmm, ymm).
  use512 = any(tag.startswith('avx512') for tag in tags)
  usexcr = 'avx' in tags or 'avx2' in tags or use512
  if usexcr:
    do('osxsave')

  # NOTE(review): with GCC/clang <cpuid.h>, __get_cpuid does not load ECX,
  # so on the __FILC__ path the level-7 sub-leaf may be unspecified
  # (__get_cpuid_count(func,0,...) would set it); confirm against Fil-C.
  result = r'''/*
gcc has __builtin_cpu_supports("avx2")
but implemented it incorrectly until 2018:
https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85100

cannot expect all of those machines to have upgraded gcc yet

furthermore, why is checking just for avx2 enough?
has intel guaranteed that it will never introduce
a cpu with avx2 instructions and without (e.g.) sse4.2?

so manually check cpuid and xgetbv here
and include all the "lower" instruction sets
rather than trying to guess which ones are implied
*/

#include <inttypes.h>

#ifdef __FILC__
#include <cpuid.h>
#include <stdfil.h>
#elif defined(_MSC_VER)
#include <immintrin.h>
#include <intrin.h>
#endif

static void cpuid0(uint32_t func,uint32_t *a,uint32_t *b,uint32_t *c,uint32_t *d)
{
#ifdef __FILC__
  __get_cpuid(func,a,b,c,d);
#elif defined(_MSC_VER)
  uint32_t x[4];
  __cpuid(x,func);
  *a = x[0];
  *b = x[1];
  *c = x[2];
  *d = x[3];
#else
  asm volatile("cpuid":"=a"(*a),"=b"(*b),"=c"(*c),"=d"(*d):"a"(func),"c"(0));
#endif
}
'''

  if usexcr:
    result += '''
static uint64_t xgetbv0(void)
{
#ifdef __FILC__
  return zxgetbv();
#elif defined(_MSC_VER)
  return _xgetbv(0);
#else
  uint32_t a,d;
  asm(".byte 15;.byte 1;.byte 208":"=a"(a),"=d"(d):"c"(0));
  return a|(((uint64_t)d)<<32);
#endif
}
'''

  for (i,j) in cpuidbits:
    if (i,j) not in want: continue
    bits = sorted(want[i,j])
    # 1U, not 1: several registers above have a feature at bit 31
    # (e.g. avx512vl), and 1<<31 overflows int (undefined in ISO C)
    macro = '|'.join(f'(1U<<{k})' for k,name in bits)
    comment = '; '.join(f'{k}={name}' for k,name in bits)
    result += '\n'
    result += f'#define {WANT(i,j)} ({macro})\n'
    result += f'/* {comment} */\n'

  if usexcr:
    xcrbits = [(1,'xmm'),(2,'ymm')]
    if use512:
      xcrbits += [(5,'opmask'),(6,'zmmhi256'),(7,'hi16zmm')]
    macro = '|'.join(f'(1<<{k})' for k,name in xcrbits)
    comment = '; '.join(f'{k}={name}' for k,name in xcrbits)
    result += f'''
#define WANT_XCR ({macro})
/* {comment} */
'''

  result += r'''
int supports(void)
{
  uint32_t cpuidmax,id0,id1,id2;
  uint32_t feature0,feature1,feature2,feature3;
'''

  if usexcr:
    result += r'''  uint64_t xcr;
'''

  # level 0 reports the highest supported standard level
  maxlevel = max(i for (i,j) in want if not extended(i))
  result += fr'''
  cpuid0(0,&cpuidmax,&id0,&id1,&id2);
  if (cpuidmax < {maxlevel}) return 0;
'''

  # level 0x80000000 reports the highest supported extended level
  if any(extended(i) for (i,j) in want):
    maxextlevel = max(i for (i,j) in want if extended(i))

    result += fr'''
  cpuidmax = feature1 = feature2 = feature3 = 0;
  cpuid0(0x80000000,&cpuidmax,&feature1,&feature2,&feature3);
  if (cpuidmax < {hex(maxextlevel)}) return 0;
'''

  # one cpuid call per needed level, then test each needed register
  for i in sorted({i for (i,j) in want}):
    result += '\n'
    result += f'  cpuid0({hex(i) if extended(i) else i},&feature0,&feature1,&feature2,&feature3);\n'
    for j in sorted({j2 for (i2,j2) in want if i2 == i}):
      result += f'  if ({WANT(i,j)} != ({WANT(i,j)} & feature{j})) return 0;\n'

  if usexcr:
    result += r'''
  xcr = xgetbv0();
  if (WANT_XCR != (WANT_XCR & xcr)) return 0;
'''

  result += r'''
  return 1;
}
'''

  with open(fncompilerc,'w') as f:
    f.write(result)

# ===== fncompiler

# cpuid tag names whose gcc/clang -m option is spelled differently;
# every other tag <t> is passed as -m<t>
MFLAG = {
  'sse41': 'sse4.1',
  'sse42': 'sse4.2',
  'bmi1': 'bmi',
  'aesni': 'aes',
  'pclmulqdq': 'pclmul',
  'rdrand': 'rdrnd',
  'cmpxchg16b': 'cx16',
  'prefetchw': 'prfchw',
  'lahfsahf64': 'sahf',
  'avx512vpopc': 'avx512vpopcntdq',
}

# Write fncompiler: one compiler command line per line, gcc and clang, each
# built from argv[2] (base compiler args), the -m options for tags, and any
# further argv entries. Without argv[2] the file is created empty.
with open(fncompiler,'w') as f:
  if len(sys.argv) >= 3:
    parts = [sys.argv[2]]
    if tags[0] == 'amd64':
      parts += ['-mmmx','-msse','-msse2']
    parts += [f'-m{MFLAG.get(tag,tag)}' for tag in tags[1:]]
    parts += sys.argv[3:]
    ccarg = ' '.join(parts)
    gcc = f'gcc -Wall -fPIC -fwrapv {ccarg}\n'
    clang = f'clang -Wall -fPIC -fwrapv -Qunused-arguments {ccarg}\n'
    # a non-empty CLANGFIRST puts the clang line first
    if os.getenv('CLANGFIRST'):
      f.write(clang+gcc)
    else:
      f.write(gcc+clang)
