summaryrefslogtreecommitdiff
path: root/script/vm/compiler.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm/compiler.lua')
-rw-r--r--script/vm/compiler.lua231
1 files changed, 185 insertions, 46 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index d4309cc9..2253c83a 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -5,12 +5,17 @@ local rpath = require 'workspace.require-path'
local files = require 'files'
---@class vm
local vm = require 'vm.vm'
+local plugin = require 'plugin'
---@class parser.object
----@field _compiledNodes boolean
----@field _node vm.node
----@field cindex integer
----@field func parser.object
+---@field _compiledNodes boolean
+---@field _node vm.node
+---@field cindex integer
+---@field func parser.object
+---@field hideView boolean
+---@field package _returns? parser.object[]
+---@field package _callReturns? parser.object[]
+---@field package _asCache? parser.object[]
-- 该函数有副作用,会给source绑定node!
---@param source parser.object
@@ -28,6 +33,12 @@ function vm.bindDocs(source)
end
if doc.type == 'doc.class' then
vm.setNode(source, vm.compileNode(doc))
+ for j = i + 1, #docs do
+ local overload = docs[j]
+ if overload.type == 'doc.overload' then
+ overload.overload.hideView = true
+ end
+ end
return true
end
if doc.type == 'doc.param' then
@@ -55,6 +66,9 @@ function vm.bindDocs(source)
vm.setNode(source, vm.compileNode(ast))
return true
end
+ if doc.type == 'doc.overload' then
+ vm.setNode(source, vm.compileNode(doc))
+ end
end
return false
end
@@ -473,6 +487,7 @@ function vm.getReturnOfFunction(func, index)
func._returns = {}
end
if not func._returns[index] then
+ ---@diagnostic disable-next-line: missing-fields
func._returns[index] = {
type = 'function.return',
parent = func,
@@ -511,7 +526,7 @@ local function getReturnOfSetMetaTable(args)
node:merge(vm.compileNode(tbl))
end
if mt then
- vm.compileByParentNode(mt, '__index', function (src)
+ vm.compileByParentNodeAll(mt, '__index', function (src)
for n in vm.compileNode(src):eachObject() do
if n.type == 'global'
or n.type == 'local'
@@ -522,6 +537,8 @@ local function getReturnOfSetMetaTable(args)
end
end)
end
+ --过滤nil
+ node:remove 'nil'
return node
end
@@ -568,6 +585,7 @@ local function getReturn(func, index, args)
end
if not func._callReturns[index] then
local call = func.parent
+ ---@diagnostic disable-next-line: missing-fields
func._callReturns[index] = {
type = 'call.return',
parent = call,
@@ -637,8 +655,9 @@ end
---@param source parser.object | vm.variable
---@param key string|vm.global|vm.ANY
----@param pushResult fun(source: parser.object)
-function vm.compileByParentNode(source, key, pushResult)
+---@return parser.object[] docedResults
+---@return parser.object[] commonResults
+function vm.getNodesOfParentNode(source, key)
local parentNode = vm.compileNode(source)
local docedResults = {}
local commonResults = {}
@@ -691,6 +710,16 @@ function vm.compileByParentNode(source, key, pushResult)
end)
end
+ return docedResults, commonResults
+end
+
+-- 遍历所有字段(按照优先级)
+---@param source parser.object | vm.variable
+---@param key string|vm.global|vm.ANY
+---@param pushResult fun(source: parser.object)
+function vm.compileByParentNode(source, key, pushResult)
+ local docedResults, commonResults = vm.getNodesOfParentNode(source, key)
+
if #docedResults > 0 then
for _, res in ipairs(docedResults) do
pushResult(res)
@@ -703,6 +732,21 @@ function vm.compileByParentNode(source, key, pushResult)
end
end
+-- 遍历所有字段(无视优先级)
+---@param source parser.object | vm.variable
+---@param key string|vm.global|vm.ANY
+---@param pushResult fun(source: parser.object)
+function vm.compileByParentNodeAll(source, key, pushResult)
+ local docedResults, commonResults = vm.getNodesOfParentNode(source, key)
+
+ for _, res in ipairs(docedResults) do
+ pushResult(res)
+ end
+ for _, res in ipairs(commonResults) do
+ pushResult(res)
+ end
+end
+
---@param list parser.object[]
---@param index integer
---@return vm.node
@@ -728,6 +772,11 @@ function vm.selectNode(list, index)
if not exp then
return vm.createNode(vm.declareGlobal('type', 'nil')), nil
end
+
+ if vm.bindDocs(list) then
+ return vm.compileNode(list), exp
+ end
+
---@type vm.node?
local result
if exp.type == 'call' then
@@ -834,52 +883,69 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
end
end
- for n in callNode:eachObject() do
- if n.type == 'function' then
- ---@cast n parser.object
- local sign = vm.getSign(n)
+ ---@param n parser.object
+ local function dealDocFunc(n)
+ local myEvent
+ if n.args[eventIndex] then
+ local argNode = vm.compileNode(n.args[eventIndex])
+ myEvent = argNode:get(1)
+ end
+ if not myEvent
+ or not eventMap
+ or myIndex <= eventIndex
+ or myEvent.type ~= 'doc.type.string'
+ or eventMap[myEvent[1]] then
local farg = getFuncArg(n, myIndex)
if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
- if fn.type == 'doc.type.function' then
- ---@cast fn parser.object
- if sign then
- local generic = vm.createGeneric(fn, sign)
- local args = {}
- for i = fixIndex + 1, myIndex - 1 do
- args[#args+1] = call.args[i]
- end
- local resolvedNode = generic:resolve(guide.getUri(call), args)
- vm.setNode(arg, resolvedNode)
- goto CONTINUE
- end
- end
vm.setNode(arg, fn)
- ::CONTINUE::
end
end
end
end
- if n.type == 'doc.type.function' then
- ---@cast n parser.object
- local myEvent
- if n.args[eventIndex] then
- local argNode = vm.compileNode(n.args[eventIndex])
- myEvent = argNode:get(1)
- end
- if not myEvent
- or not eventMap
- or myIndex <= eventIndex
- or myEvent.type ~= 'doc.type.string'
- or eventMap[myEvent[1]] then
- local farg = getFuncArg(n, myIndex)
- if farg then
- for fn in vm.compileNode(farg):eachObject() do
- if isValidCallArgNode(arg, fn) then
- vm.setNode(arg, fn)
+ end
+
+ ---@param n parser.object
+ local function dealFunction(n)
+ local sign = vm.getSign(n)
+ local farg = getFuncArg(n, myIndex)
+ if farg then
+ for fn in vm.compileNode(farg):eachObject() do
+ if isValidCallArgNode(arg, fn) then
+ if fn.type == 'doc.type.function' then
+ ---@cast fn parser.object
+ if sign then
+ local generic = vm.createGeneric(fn, sign)
+ local args = {}
+ for i = fixIndex + 1, myIndex - 1 do
+ args[#args+1] = call.args[i]
+ end
+ local resolvedNode = generic:resolve(guide.getUri(call), args)
+ vm.setNode(arg, resolvedNode)
+ goto CONTINUE
end
end
+ vm.setNode(arg, fn)
+ ::CONTINUE::
+ end
+ end
+ end
+ end
+
+ for n in callNode:eachObject() do
+ if n.type == 'function' then
+ ---@cast n parser.object
+ dealFunction(n)
+ elseif n.type == 'doc.type.function' then
+ ---@cast n parser.object
+ dealDocFunc(n)
+ elseif n.type == 'global' and n.cate == 'type' then
+ ---@cast n vm.global
+ local overloads = vm.getOverloadsByTypeName(n.name, guide.getUri(arg))
+ if overloads then
+ for _, func in ipairs(overloads) do
+ dealDocFunc(func)
end
end
end
@@ -966,6 +1032,55 @@ local function compileForVars(source, target)
end
---@param source parser.object
+local function compileFunctionParam(func, source)
+ -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
+ local funcNode = vm.compileNode(func)
+ for n in funcNode:eachObject() do
+ if n.type == 'doc.type.function' then
+ for index, arg in ipairs(n.args) do
+ if func.args[index] == source then
+ local argNode = vm.compileNode(arg)
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
+ end
+ end
+ return true
+ end
+ end
+ end
+ end
+ if func.parent.type == 'local' then
+ local refs = func.parent.ref
+ local findCall
+ if refs then
+ for i, ref in ipairs(refs) do
+ if ref.parent.type == 'call' then
+ findCall = ref.parent
+ break
+ end
+ end
+ end
+ if findCall and findCall.args then
+ local index
+ for i, arg in ipairs(source.parent) do
+ if arg == source then
+ index = i
+ break
+ end
+ end
+ if index then
+ local callerArg = findCall.args[index]
+ if callerArg then
+ vm.setNode(source, vm.compileNode(callerArg))
+ return true
+ end
+ end
+ end
+ end
+end
+
+---@param source parser.object
local function compileLocal(source)
local myNode = vm.setNode(source, source)
@@ -992,6 +1107,7 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(source.value))
end
end
+
-- function x.y(self, ...) --> function x:y(...)
if source[1] == 'self'
and not hasMarkDoc
@@ -1027,6 +1143,7 @@ local function compileLocal(source)
vm.setNode(source, vm.declareGlobal('type', 'any'))
end
end
+
-- for x in ... do
if source.parent.type == 'in' then
compileForVars(source.parent, source)
@@ -1066,6 +1183,12 @@ local function compileLocal(source)
end
end
+ if source.value
+ and source.value.type == 'nil'
+ and not myNode:hasKnownType() then
+ vm.setNode(source, vm.compileNode(source.value))
+ end
+
myNode.hasDefined = hasMarkDoc or hasMarkParam or hasMarkValue
end
@@ -1112,6 +1235,9 @@ local compilerSwitch = util.switch()
end)
: case 'table'
: call(function (source)
+ if vm.bindAs(source) then
+ return
+ end
vm.setNode(source, source)
if source.parent.type == 'callargs' then
@@ -1119,6 +1245,16 @@ local compilerSwitch = util.switch()
vm.compileCallArg(source, call)
end
+ if source.parent.type == 'return' then
+ local myIndex = util.arrayIndexOf(source.parent, source)
+ ---@cast myIndex -?
+ local parentNode = vm.selectNode(source.parent, myIndex)
+ if not parentNode:isEmpty() then
+ vm.setNode(source, parentNode)
+ return
+ end
+ end
+
if source.parent.type == 'setglobal'
or source.parent.type == 'local'
or source.parent.type == 'setlocal'
@@ -1315,8 +1451,8 @@ local compilerSwitch = util.switch()
hasMarkDoc = vm.bindDocs(source)
end
+ local key = guide.getKeyName(source)
if not hasMarkDoc then
- local key = guide.getKeyName(source)
if key then
vm.compileByParentNode(source.node, key, function (src)
if src.type == 'doc.field'
@@ -1341,8 +1477,11 @@ local compilerSwitch = util.switch()
end)
end
- if not hasMarkDoc and source.value then
- vm.setNode(source, vm.compileNode(source.value))
+ if source.value then
+ if not hasMarkDoc
+ or (type(key) == 'string' and util.stringStartWith(key, '__')) then
+ vm.setNode(source, vm.compileNode(source.value))
+ end
end
end)
@@ -1617,7 +1756,7 @@ local compilerSwitch = util.switch()
if state.type == 'doc.return'
or state.type == 'doc.param' then
local func = state.bindSource
- if func.type == 'function' then
+ if func and func.type == 'function' then
local node = guide.getFunctionSelfNode(func)
if node then
vm.setNode(source, vm.compileNode(node))