summaryrefslogtreecommitdiff
path: root/script/core/command/autoRequire.lua
blob: 9f3ff92970f9cc2f8b01856bd914f4be9cd47d26 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
local files  = require 'files'
local furi   = require 'file-uri'
local config = require 'config'
local rpath  = require 'workspace.require-path'
local client = require 'client'
local lang   = require 'language'
local guide  = require 'parser.guide'

--- Report whether `pos` falls inside one of the comments recorded in
--- `state.comms` (both `start` and `finish` are treated as inclusive).
---
--- The scan gives up as soon as it meets a comment that starts after `pos`;
--- this presumably relies on `state.comms` being ordered by `start` — confirm
--- against the parser that fills it.
---@param state table  # must expose a `comms` array of `{ start, finish }`
---@param pos   integer
---@return boolean
local function inComment(state, pos)
    for _, comment in ipairs(state.comms) do
        if pos < comment.start then
            -- Every comment seen so far ended before `pos`, and this one
            -- begins after it: stop looking.
            return false
        end
        if pos <= comment.finish then
            return true
        end
    end
    return false
end

--- Walk the lines of a file from the top to decide where a new `require`
--- statement should be inserted and how it should be formatted.
---
--- Lines whose first column lies inside a comment are skipped. A line that
--- contains the plain text `require` updates the insert row (to the row just
--- below it) and the detected formatting. The walk stops at the first line
--- without `require` once a require line has been seen; before that, it stops
--- at the first line that is neither blank, nor starts with `local `, nor
--- starts with `--`.
---
--- Returns nothing when the file has no text or no parsed state.
---@param uri uri
---@return integer? row  # row to insert at; 0 when no require line was found
---@return table?   fmt  # `pair`: parenthesised call, `quot`: quote char, `col`: column of `=`
local function findInsertRow(uri)
    local content = files.getText(uri)
    local state   = files.getState(uri)
    if not (state and content) then
        return
    end

    local style = { pair = false, quot = '"', col = nil }
    local insertRow
    local offsets = state.lines
    for lineNo = 0, #offsets do
        if not inComment(state, guide.positionOf(lineNo, 0)) then
            -- `offsets[lineNo]` is used as the start index of the line in `content`.
            local lineText = content:match('[^\r\n]*', offsets[lineNo])
            if lineText:find('require', 1, true) then
                insertRow  = lineNo + 1
                -- Each require line overrides the call style; quote and
                -- column are only overridden when the line provides them.
                style.pair = (lineText:find('%(')) ~= nil
                style.quot = lineText:match([=[(['"])]=]) or style.quot
                style.col  = (lineText:find('=')) or style.col
            elseif insertRow then
                -- First non-require line after the require run.
                break
            else
                local isPreamble = lineText:match('^local%s')
                    or lineText:match('^%s*$')
                    or lineText:match('^%-%-')
                if not isPreamble then
                    break
                end
            end
        end
    end
    return insertRow or 0, style
end

--- Ask the user which require name to use for an auto-require.
---
--- One choice is offered per distinct `lang.script(visible.name)` label, in
--- the order of `visiblePaths`, followed by a final "disable" choice. Picking
--- "disable" sets `Lua.completion.autoRequire` to `false` for `uri` and
--- yields no name; dismissing the request also yields no name.
---@async
---@param uri          uri
---@param visiblePaths table[]  # entries are read through their `name` field
---@return string?              # the chosen require name, if any
local function askAutoRequire(uri, visiblePaths)
    local options     = {}
    local labelToName = {}
    for _, entry in ipairs(visiblePaths) do
        local requireName = entry.name
        local label       = lang.script(requireName)
        -- Only the first path producing a given label is kept.
        if not labelToName[label] then
            labelToName[label]    = requireName
            options[#options + 1] = label
        end
    end

    local disableLabel = lang.script.COMPLETION_DISABLE_AUTO_REQUIRE
    options[#options + 1] = disableLabel

    local answer = client.awaitRequestMessage(
        'Info',
        lang.script.COMPLETION_ASK_AUTO_REQUIRE,
        options
    )
    if not answer then
        return
    end
    if answer ~= disableLabel then
        return labelToName[answer]
    end

    client.setConfig {
        {
            key    = 'Lua.completion.autoRequire',
            action = 'set',
            value  = false,
            uri    = uri,
        }
    }
end

--- Insert `local <name> = require <result>` at the beginning of `row`.
---
--- The statement is shaped after `fmt`:
---   * `fmt.quot == "'"` switches the string literal to single quotes,
---   * `fmt.pair` wraps the argument in parentheses (`require('x')`),
---   * `fmt.col`, when set, is the column at which `=` should be placed so the
---     new line lines up with the neighbouring require statements.
---
--- At least one space is always kept between the name and `=`: when the name
--- is too long for `=` to reach `fmt.col`, alignment is dropped instead of
--- emitting `local name= require ...`.
---@param uri    uri
---@param row    integer  # row handed to `guide.positionOf(row, 0)`
---@param name   string   # name of the local variable to declare
---@param result string   # module name passed to `require`
---@param fmt    table    # `{ pair, quot, col }`
local function applyAutoRequire(uri, row, name, result, fmt)
    -- `%q` yields a double-quoted literal with every needed escape in place.
    local literal = ('%q'):format(result)
    if fmt.quot == "'" then
        -- Re-quote with single quotes: escape `'`, un-escape `"`.
        -- Assigning to a local keeps only the string result of `gsub`.
        local inner = literal:sub(2, -2)
            :gsub([[']], [[\']])
            :gsub([[\"]], [["]])
        literal = ([['%s']]):format(inner)
    end

    local argument
    if fmt.pair then
        argument = ('(%s)'):format(literal)
    else
        argument = (' %s'):format(literal)
    end

    local prefix = ('local %s'):format(name)
    local padWidth = 1
    if fmt.col then
        -- `=` lands at column #prefix + padWidth + 1; never go below one
        -- space, even if that means `=` ends up right of `fmt.col`.
        padWidth = math.max(1, fmt.col - #prefix - 1)
    end

    local text = ('local %s%s= require%s\n'):format(name, (' '):rep(padWidth), argument)
    local position = guide.positionOf(row, 0)
    client.editText(uri, {
        {
            start  = position,
            finish = position,
            text   = text,
        }
    })
end

--- Command entry point: add a `require` statement for `data.target` to the
--- file `data.uri`.
---
--- Reads `uri`, `target`, `name` and `requireName` from `data`. When
--- `requireName` is not supplied the user is asked to pick one of the visible
--- require paths (shortest name first). Does nothing when the file has no
--- state, when no visible path exists, or when the user picks nothing.
---@async
return function (data)
    ---@type uri
    local uri         = data.uri
    local targetUri   = data.target
    local localName   = data.name
    local requireName = data.requireName

    if not files.getState(uri) then
        return
    end

    local visiblePaths = rpath.getVisiblePath(uri, furi.decode(targetUri))
    if not visiblePaths or #visiblePaths == 0 then
        return
    end
    -- Shortest require name first.
    table.sort(visiblePaths, function (lhs, rhs)
        return #lhs.name < #rhs.name
    end)

    requireName = requireName or askAutoRequire(uri, visiblePaths)
    if not requireName then
        return
    end

    local row, fmt = findInsertRow(uri)
    if row and fmt then
        applyAutoRequire(uri, row, localName, requireName, fmt)
    end
end