Skip to content

Commit

Permalink
more work
Browse files Browse the repository at this point in the history
  • Loading branch information
CapsAdmin committed May 16, 2024
1 parent c850b8d commit beb51f5
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,10 @@ do
if not is_dots(data[0].cFileName) then
out[i] = ffi.string(data[0].cFileName)
i = i + 1
end until ffi.C.FindNextFileA(handle, data) == 0
end
until ffi.C.FindNextFileA(handle, data) == 0

if ffi.C.FindClose(handle) == 0 then return nil, last_error() end
if ffi.C.FindClose(assert(handle)) == 0 then return nil, last_error() end

return out
end
Expand All @@ -182,4 +183,4 @@ do
end
end

return fs
return fs
102 changes: 68 additions & 34 deletions nattlua/c_declarations/analyzer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -170,40 +170,49 @@ local function cast(self, node, out)
local env = self.env
local analyzer = self.analyzer
local typs = self.typs
local vars = self.vars

if node.type == "array" then
local size

if node.size == "?" then
size = table.remove(self.dollar_signs_vars, 1)
else
size = LNumber(tonumber(node.size) or math.huge)
end

return (
env.FFIArray:Call(
analyzer,
Tuple({LNumber(tonumber(node.size) or math.huge), cast(self, assert(node.of), out)})
):Unpack()
env.FFIArray:Call(analyzer, Tuple({size, cast(self, assert(node.of), out)})):Unpack()
)
elseif node.type == "pointer" then
if node.of.type == "type" and #node.of.modifiers == 1 and node.of.modifiers[1] == "void" then
if
node.of.type == "type" and
#node.of.modifiers == 1 and
node.of.modifiers[1] == "void"
then
return Any() -- TODO: is this true?
end

local res = (env.FFIPointer:Call(analyzer, Tuple({cast(self, assert(node.of), out)})):Unpack())

if node.of.type == "type" and node.of.modifiers[1] == "const" and node.of.modifiers[2] == "char" then
if self.FUNCTION_ARGUMENT then
return Union({res, String(), Nil()})
end

if
node.of.type == "type" and
node.of.modifiers[1] == "const" and
node.of.modifiers[2] == "char"
then
return Union({res, String(), Nil()})
end

return Union({res, Nil()})
elseif node.type == "type" then
for _, v in ipairs(node.modifiers) do
if type(v) == "table" then

-- only catch struct, union and enum TYPE declarations
if v.type == "struct" or v.type == "union" then
local tbl

if v.fields then
tbl = Table()


local ident = v.identifier

if not ident and #node.modifiers > 0 then
Expand All @@ -212,18 +221,29 @@ local function cast(self, node, out)

self.current_nodes = self.current_nodes or {}
table.insert(self.current_nodes, 1, {ident = ident, tbl = tbl})

--tbl:Set(LString("__id"), LString(("%p"):format({})))
for _, v in ipairs(v.fields) do
tbl:Set(LString(v.identifier), cast(self, v, out))
end

table.remove(self.current_nodes, 1)

table.insert(out, {identifier = ident, obj = tbl})
self.typs_write:Set(LString(ident), tbl)
else
tbl = typs:Get(LString(v.identifier)) or Table()
local current = self.current_nodes and self.current_nodes[1]

if current and current.ident == v.identifier then
-- recursion
tbl = current.tbl
else
tbl = typs:Get(LString(v.identifier)) or
self.typs_write:Get(LString(v.identifier)) or
Table()
end

table.insert(out, {identifier = v.identifier, obj = tbl})
self.typs_write:Set(LString(v.identifier), tbl)
end
elseif v.type == "enum" then
local tbl = Table()
Expand All @@ -235,6 +255,7 @@ local function cast(self, node, out)
end

table.insert(out, {identifier = v.identifier, obj = tbl})
self.typs_write:Set(LString(v.identifier), tbl)
end

-- catch variable declarations
Expand All @@ -247,22 +268,30 @@ local function cast(self, node, out)

local tbl = typs:Get(LString(ident))

if not tbl then tbl = self.typs_write:Get(LString(ident)) end

if not tbl and v.fields then
tbl = Table()
self.current_nodes = self.current_nodes or {}
table.insert(self.current_nodes, 1, {ident = ident, tbl = tbl})

for _, v in ipairs(v.fields) do
tbl:Set(LString(v.identifier), cast(self, v, out))
end

table.remove(self.current_nodes, 1)
end

if not tbl then
if not tbl and self.current_nodes then
local current = self.current_nodes[1]

if current and current.ident == ident then
-- recursion
tbl = current.tbl
end
end

if not tbl then tbl = Table() end

return (tbl)
elseif v.type == "enum" then
Expand All @@ -273,7 +302,7 @@ local function cast(self, node, out)
end
end
end

local t = node.modifiers[1]

if
Expand Down Expand Up @@ -328,7 +357,7 @@ local function cast(self, node, out)
return Boolean()
elseif t == "void" then
return Nil()
elseif t == "$" then
elseif t == "$" or t == "?" then
local res = table.remove(self.dollar_signs_vars, 1)
return res
elseif t == "va_list" then
Expand All @@ -340,11 +369,9 @@ local function cast(self, node, out)
local args = {}
local rets = {}

self.FUNCTION_ARGUMENT = true
for i, v in ipairs(node.args) do
table.insert(args, cast(self, v, out))
end
self.FUNCTION_ARGUMENT = false

return (Function(Tuple(args), Tuple({cast(self, assert(node.rets), out)})))
elseif node.type == "root" then
Expand All @@ -354,28 +381,35 @@ local function cast(self, node, out)
end
end


function META:AnalyzeRoot(ast, vars, typs)
vars = vars or Table()
typs = typs or Table()
self.typs = typs
self.Callback = function(node, real_node, typedef)
-- new output
self.typs = typs or Table()
self.vars = vars or Table()
local typs = Table()
local vars = Table()
self.typs_write = typs
self.vars_write = vars
self.Callback = function(node, real_node, typedef)
local out = {}
local obj = cast(self, node, out)
if typedef then
typs:Set(LString(real_node.tokens["potential_identifier"].value), obj)
else
vars:Set(LString(real_node.tokens["potential_identifier"] and real_node.tokens["potential_identifier"].value or "uhoh"), obj)
local ident = real_node.tokens["potential_identifier"] and
real_node.tokens["potential_identifier"].value

if not ident then ident = "uhhoh" end

if ident then
if typedef then
typs:Set(LString(ident), obj)
else
vars:Set(LString(ident), obj)
end
end

for _, typedef in ipairs(out) do
typs:Set(LString(assert(typedef.identifier)), typedef.obj)
end

end
self:WalkRoot(ast)

print(vars, typs)
return vars, typs
end

Expand Down
50 changes: 28 additions & 22 deletions nattlua/c_declarations/main.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
local cparser = {}
local table_print = require("nattlua.other.table_print")
local Function = require("nattlua.types.function").Function
local LuaTypeFunction = require("nattlua.types.function").LuaTypeFunction
local LNumber = require("nattlua.types.number").LNumber
Expand All @@ -16,7 +15,6 @@ local Nilable = require("nattlua.types.union").Nilable
local Tuple = require("nattlua.types.tuple").Tuple
local Boolean = require("nattlua.types.union").Boolean


local function C_DECLARATIONS()
local analyzer = assert(
require("nattlua.analyzer.context"):GetCurrentAnalyzer(),
Expand All @@ -25,6 +23,16 @@ local function C_DECLARATIONS()
local env = analyzer:GetScopeHelper(analyzer.function_scope)
return env.typesystem.ffi:Get(ConstString("C"))
end

local function C_DECLARATIONS_RUNTIME()
local analyzer = assert(
require("nattlua.analyzer.context"):GetCurrentAnalyzer(),
"no analyzer in context"
)
local env = analyzer:GetScopeHelper(analyzer.function_scope)
return env.runtime.ffi:Get(ConstString("C"))
end

local function parse2(c_code, mode, env, analyzer, ...)
local Lexer = require("nattlua.c_declarations.lexer").New
local Parser = require("nattlua.c_declarations.parser").New
Expand All @@ -33,9 +41,9 @@ local function parse2(c_code, mode, env, analyzer, ...)
local Code = require("nattlua.code").New
local Compiler = require("nattlua.compiler")

if mode == "typeof" then
c_code = "typedef " .. c_code .. " TYPEOF_CDECL;"
end
if mode == "typeof" then c_code = "typedef " .. c_code .. " TYPEOF_CDECL;" end

if mode == "ffinew" then c_code = c_code .. " VAR_NAME;" end

local code = Code(c_code, "test.c")
local lex = Lexer(code)
Expand All @@ -46,36 +54,38 @@ local function parse2(c_code, mode, env, analyzer, ...)
local emitter = Emitter({skip_translation = true})
local res = emitter:BuildCode(ast)
local a = Analyzer()

if parser.dollar_signs then
local function gen(...)
local new = {}

for i, v in ipairs(parser.dollar_signs) do
local ct = select(i, ...)
if not ct then
error("expected ctype at argument #" .. i, 2)
end

if not ct then error("expected ctype at argument #" .. i, 2) end

table.insert(new, 1, ct)
end

return new
end

a.dollar_signs_typs = gen(...)
a.dollar_signs_vars = gen(...)
end

a.env = env.typesystem
a.analyzer = analyzer
return a:AnalyzeRoot(ast)
return a:AnalyzeRoot(ast, C_DECLARATIONS_RUNTIME(), C_DECLARATIONS())
end


function cparser.sizeof(cdecl, len)
-- TODO: support non string sizeof
if jit and cdecl.Type == "string" and cdecl:IsLiteral() then
local analyzer = require("nattlua.analyzer.context"):GetCurrentAnalyzer()
local env = analyzer:GetScopeHelper(analyzer.function_scope)
local vars, typs = parse2(cdecl:GetData(), "typeof", env, analyzer)
local ctype = vars:GetData()[1].val

local ctype = typs:GetData()[1].val
local ffi = require("ffi")
local ok, val = pcall(ffi.sizeof, cdecl:GetData(), len and len:GetData() or nil)

Expand All @@ -88,28 +98,26 @@ end
function cparser.cdef(cdecl, ...)
assert(cdecl:IsLiteral(), "cdecl must be a string literal")
local analyzer = require("nattlua.analyzer.context"):GetCurrentAnalyzer()

local env = analyzer:GetScopeHelper(analyzer.function_scope)
local vars, typs = parse2(cdecl:GetData(), "cdef", env, analyzer, ...)

for _, kv in ipairs(typs:GetData()) do
analyzer:NewIndexOperator(C_DECLARATIONS(), kv.key, kv.val)
end

for _, kv in ipairs(vars:GetData()) do
analyzer:NewIndexOperator(C_DECLARATIONS(), kv.key, kv.val)
analyzer:NewIndexOperator(C_DECLARATIONS_RUNTIME(), kv.key, kv.val)
end

return nil
end

function cparser.cast(cdecl, src)
assert(cdecl:IsLiteral(), "cdecl must be a string literal")


local analyzer = require("nattlua.analyzer.context"):GetCurrentAnalyzer()
local env = analyzer:GetScopeHelper(analyzer.function_scope)
local vars, typs = parse2(cdecl:GetData(), "typeof", env, analyzer)
local ctype = vars:GetData()[1].val
local ctype = typs:GetData()[1].val

-- TODO, this tries to extract cdata from cdata | nil, since if we cast a valid pointer it cannot be invalid when returned
if ctype.Type == "union" then
Expand Down Expand Up @@ -143,7 +151,7 @@ function cparser.typeof(cdecl, ...)
local analyzer = require("nattlua.analyzer.context"):GetCurrentAnalyzer()
local env = analyzer:GetScopeHelper(analyzer.function_scope)
local vars, typs = parse2(cdecl:GetData(), "typeof", env, analyzer, ...)
local ctype = vars:GetData()[1].val
local ctype = typs:GetData()[1].val

-- TODO, this tries to extract cdata from cdata | nil, since if we cast a valid pointer it cannot be invalid when returned
if ctype.Type == "union" then
Expand Down Expand Up @@ -194,18 +202,16 @@ function cparser.get_type(cdecl, ...)
local analyzer = require("nattlua.analyzer.context"):GetCurrentAnalyzer()
local env = analyzer:GetScopeHelper(analyzer.function_scope)
local vars, typs = parse2(cdecl:GetData(), "typeof", env, analyzer, ...)
local ctype = vars:GetData()[1].val
local ctype = typs:GetData()[1].val
return ctype
end

function cparser.new(cdecl, ...)
local analyzer = require("nattlua.analyzer.context"):GetCurrentAnalyzer()
local env = analyzer:GetScopeHelper(analyzer.function_scope)
local vars, typs = parse2(cdecl:GetData(), "typeof", env, analyzer, ...)
local vars, typs = parse2(cdecl:GetData(), "ffinew", env, analyzer, ...)
local ctype = vars:GetData()[1].val

print(vars, typs)

if ctype.is_enum then return ... end

return ctype
Expand Down
Loading

0 comments on commit beb51f5

Please sign in to comment.