From c461d51bf40563c3b74c652d8789ae302ea7c9e1 Mon Sep 17 00:00:00 2001 From: sumneko Date: Sat, 12 Mar 2022 04:48:33 +0800 Subject: update --- script/parser/guide.lua | 21 ++++++++++++++++ script/vm/compiler.lua | 59 ++++++++++++++++++++++++++++++++++++++------ test/type_inference/init.lua | 27 ++++++++++++++++---- 3 files changed, 95 insertions(+), 12 deletions(-) diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 73857149..f490b306 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -1182,4 +1182,25 @@ function m.isOOP(source) return false end +local baseTypeMap = { + ['unknown'] = true, + ['any'] = true, + ['true'] = true, + ['false'] = true, + ['nil'] = true, + ['boolean'] = true, + ['number'] = true, + ['string'] = true, + ['table'] = true, + ['function'] = true, + ['thread'] = true, + ['userdata'] = true, +} + +---@param str string +---@return boolean +function m.isBaseType(str) + return baseTypeMap[str] == true +end + return m diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index c46519fd..9da8a55b 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -304,7 +304,7 @@ local function bindDocs(source) end end if doc.type == 'doc.class' then - if source.type == 'local' + if (source.type == 'local' and not isParam) or (source._globalNode and guide.isSet(source)) then hasFounded = true nodeMgr.setNode(source, m.compileNode(doc)) @@ -407,7 +407,21 @@ local function selectNode(source, list, index) return nodeMgr.setNode(source, result) end -local function setCallBackNode(source, call, callNode, fixIndex) +---@param source parser.object +---@param node vm.node +---@return boolean +local function isValidCallArgNode(source, node) + if source.type == 'function' then + return node.type == 'doc.type.function' + end + if source.type == 'table' then + return node.type == 'doc.type.table' + or (node.type == 'global' and node.cate == 'type' and not guide.isBaseType(node.name)) + end + return false +end + +local function setCallArgNode(source, call, callNode, fixIndex) local valueMgr = require 'vm.value' local myIndex for i, arg in ipairs(call.args) do @@ -416,6 +430,7 @@ local function setCallBackNode(source, call, callNode, fixIndex) break end end + local eventIndex = 1 local eventArg = call.args[eventIndex + fixIndex] if eventArg and eventArg.dummy then @@ -423,11 +438,12 @@ local function setCallBackNode(source, call, callNode, fixIndex) eventArg = call.args[eventIndex + fixIndex] end local eventMap = valueMgr.getLiterals(eventArg) + for n in nodeMgr.eachNode(callNode) do if n.type == 'function' then local arg = n.args[myIndex] for fn in nodeMgr.eachNode(m.compileNode(arg)) do - if fn.type == 'doc.type.function' then + if isValidCallArgNode(source, fn) then nodeMgr.setNode(source, fn) end end @@ -440,7 +456,7 @@ local function setCallBackNode(source, call, callNode, fixIndex) or eventMap[event[1]] then local arg = n.args[myIndex] for fn in nodeMgr.eachNode(m.compileNode(arg)) do - if fn.type == 'doc.type.function' then + if isValidCallArgNode(source, fn) then nodeMgr.setNode(source, fn) end end @@ -452,7 +468,6 @@ end local compilerMap = util.switch() : case 'nil' : case 'boolean' - : case 'table' : case 'integer' : case 'number' : case 'string' @@ -463,6 +478,32 @@ local compilerMap = util.switch() : call(function (source) nodeMgr.setNode(source, source) end) + : case 'table' + : call(function (source) + nodeMgr.setNode(source, source) + + if source.parent.type == 'callargs' then + local call = source.parent.parent + local callNode = m.compileNode(call.node) + setCallArgNode(source, call, callNode, 0) + + if call.node.special == 'pcall' + or call.node.special == 'xpcall' then + local fixIndex = call.node.special == 'pcall' and 1 or 2 + callNode = m.compileNode(call.args[1]) + setCallArgNode(source, call, callNode, fixIndex) + end + end + + if source.parent.type == 'setglobal' + or source.parent.type == 'setlocal' + or source.parent.type == 'tablefield' + or source.parent.type == 'tableindex' + or source.parent.type == 'setfield' + or source.parent.type == 'setindex' then + nodeMgr.setNode(source, m.compileNode(source.parent)) + end + end) : case 'function' : call(function (source) nodeMgr.setNode(source, source) @@ -479,13 +520,13 @@ local compilerMap = util.switch() if source.parent.type == 'callargs' then local call = source.parent.parent local callNode = m.compileNode(call.node) - setCallBackNode(source, call, callNode, 0) + setCallArgNode(source, call, callNode, 0) if call.node.special == 'pcall' or call.node.special == 'xpcall' then local fixIndex = call.node.special == 'pcall' and 1 or 2 callNode = m.compileNode(call.args[1]) - setCallBackNode(source, call, callNode, fixIndex) + setCallArgNode(source, call, callNode, fixIndex) end end end) @@ -578,6 +619,10 @@ local compilerMap = util.switch() if source.value then nodeMgr.setNode(source, m.compileNode(source.value)) end + + m.compileByParentNode(source.parent, guide.getKeyName(source), function (src) + nodeMgr.setNode(source, m.compileNode(src)) + end) end) : case 'field' : case 'method' diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 7d3e3280..bdd13760 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -17,7 +17,8 @@ local function getSource(pos) or source.type == 'getglobal' or source.type == 'field' or source.type == 'method' - or source.type == 'function' then + or source.type == 'function' + or source.type == 'table' then result = source end end) @@ -1210,14 +1211,14 @@ local x = 1 ]] -TEST 'any' [[ +TEST 'unknown' [[ ---@return number local function f(x) local = x() end ]] -TEST 'any' [[ +TEST 'unknown' [[ local mt ---@return number @@ -1226,7 +1227,7 @@ function mt:f() end local = mt() ]] -TEST 'any' [[ +TEST 'unknown' [[ local ---@class X @@ -1240,7 +1241,7 @@ local mt function mt:f() end ]] -TEST 'any' [[ +TEST 'unknown' [[ local ---@type number @@ -1270,6 +1271,22 @@ mt:loop(function () end) ]] +TEST 'C' [[ +---@class D +---@field y integer # D comment + +---@class C +---@field x integer # C comment +---@field d D + +---@param c C +local function f(c) end + +f + x = , +} +]] + TEST 'integer' [[ ---@class D ---@field y integer # D comment -- cgit v1.2.3