supa_mdx_lint/
rules.rs

1use anyhow::Result;
2use log::{debug, warn};
3use markdown::mdast::Node;
4use regex::Regex;
5use serde::Deserialize;
6use std::{collections::HashMap, fmt::Debug, marker::PhantomData};
7
8#[cfg(test)]
9use serde::Serialize;
10
11use crate::{
12    context::Context,
13    errors::{LintError, LintLevel},
14    PhaseReady, PhaseSetup,
15};
16
17mod rule001_heading_case;
18mod rule002_admonition_types;
19mod rule003_spelling;
20mod rule004_exclude_words;
21mod rule005_admonition_newlines;
22mod rule006_no_absolute_urls;
23
24pub use rule001_heading_case::Rule001HeadingCase;
25pub use rule002_admonition_types::Rule002AdmonitionTypes;
26pub use rule003_spelling::Rule003Spelling;
27pub use rule004_exclude_words::Rule004ExcludeWords;
28pub use rule005_admonition_newlines::Rule005AdmonitionNewlines;
29pub use rule006_no_absolute_urls::Rule006NoAbsoluteUrls;
30
31fn get_all_rules() -> Vec<Box<dyn Rule>> {
32    vec![
33        Box::new(Rule001HeadingCase::default()),
34        Box::new(Rule002AdmonitionTypes::default()),
35        Box::new(Rule003Spelling::default()),
36        Box::new(Rule004ExcludeWords::default()),
37        Box::new(Rule005AdmonitionNewlines),
38        Box::new(Rule006NoAbsoluteUrls::default()),
39    ]
40}
41
42pub(crate) trait Rule: Debug + RuleName {
43    fn default_level(&self) -> LintLevel;
44    fn setup(&mut self, _settings: Option<&mut RuleSettings>) {}
45    fn check(&self, ast: &Node, context: &Context, level: LintLevel) -> Option<Vec<LintError>>;
46}
47
/// Gives a rule its stable, human-readable identifier.
///
/// The name is what settings, configured levels, and rule filters are keyed
/// on throughout this module.
pub(crate) trait RuleName {
    /// Returns the rule's identifier.
    fn name(&self) -> &'static str;
}
51
52impl dyn Rule {
53    pub fn get_level(&self, configured_level: Option<LintLevel>) -> LintLevel {
54        configured_level.unwrap_or(self.default_level())
55    }
56}
57
58#[derive(Clone, Debug)]
59pub(crate) struct RuleSettings(toml::Value);
60
61#[derive(Default)]
62pub(crate) struct RegexSettings {
63    pub(crate) beginning: Option<RegexBeginning>,
64    /// Regex should only match if it matches up to the end of the word.
65    pub(crate) ending: Option<RegexEnding>,
66}
67
/// Constraint on where a configured regex may begin matching.
pub(crate) enum RegexBeginning {
    /// The match must start at the very beginning of the searched text.
    VeryBeginning,
    /// The match must start at the beginning of the text, after whitespace,
    /// or at a word boundary.
    WordBoundary,
}
72
/// Constraint on where a configured regex must stop matching.
pub(crate) enum RegexEnding {
    /// The match must be followed by whitespace, a word boundary, the end of
    /// the text, or one of a small set of punctuation characters.
    WordBoundary,
}
76
77impl RuleSettings {
78    pub fn new(table: impl Into<toml::Table>) -> Self {
79        Self(toml::Value::Table(table.into()))
80    }
81
82    #[cfg(test)]
83    pub(crate) fn has_key(&self, key: &str) -> bool {
84        self.0
85            .as_table()
86            .map(|table| table.contains_key(key))
87            .unwrap_or(false)
88    }
89
90    #[cfg(test)]
91    pub(crate) fn from_key_value(key: &str, value: toml::Value) -> Self {
92        let mut table = toml::Table::new();
93        table.insert(key.to_string(), value);
94        Self::new(table)
95    }
96
97    #[cfg(test)]
98    pub(crate) fn with_array_of_strings(key: &str, values: Vec<&str>) -> Self {
99        let array = values
100            .into_iter()
101            .map(|s| toml::Value::String(s.to_string()))
102            .collect();
103        Self::from_key_value(key, toml::Value::Array(array))
104    }
105
106    fn get_array_of_strings(&self, key: &str) -> Option<Vec<String>> {
107        let table = &self.0;
108        if let Some(toml::Value::Array(array)) = table.get(key) {
109            let mut vec = Vec::new();
110            for value in array {
111                if let toml::Value::String(string) = value {
112                    vec.push(string.to_lowercase());
113                }
114            }
115
116            if vec.is_empty() {
117                return None;
118            } else {
119                return Some(vec);
120            }
121        }
122
123        None
124    }
125
126    fn get_array_of_regexes(
127        &self,
128        key: &str,
129        settings: Option<&RegexSettings>,
130    ) -> Option<Vec<Regex>> {
131        let table = &self.0;
132        if let Some(toml::Value::Array(array)) = table.get(key) {
133            let mut vec = Vec::new();
134            for value in array {
135                if let toml::Value::String(pattern) = value {
136                    let mut pattern = pattern.to_string();
137                    if let Some(settings) = settings {
138                        match settings.beginning {
139                            Some(RegexBeginning::VeryBeginning) => {
140                                if !pattern.starts_with('^') {
141                                    pattern = format!("^{}", pattern);
142                                }
143                            }
144                            Some(RegexBeginning::WordBoundary) => {
145                                if !pattern.starts_with("\\b")
146                                    && !pattern.starts_with("\\s")
147                                    && !pattern.starts_with("^")
148                                {
149                                    pattern = format!("(?:^|\\s|\\b){}", pattern);
150                                }
151                            }
152                            None => {}
153                        }
154                        #[allow(clippy::single_match)]
155                        match settings.ending {
156                            Some(RegexEnding::WordBoundary) => {
157                                if !pattern.ends_with("\\b")
158                                    && !pattern.ends_with("\\s")
159                                    && !pattern.ends_with("$")
160                                {
161                                    pattern = format!(r#"{}(?:\s|\b|$|[.,!?'"-])"#, pattern);
162                                }
163                            }
164                            None => {}
165                        }
166                    }
167
168                    if let Ok(regex) = Regex::new(&pattern) {
169                        vec.push(regex);
170                    } else {
171                        warn!("Encountered invalid regex pattern in rule settings: {pattern}")
172                    }
173                }
174            }
175            if vec.is_empty() {
176                None
177            } else {
178                // Sort regexes by length, so the longest match is tried first.
179                //
180                // This ensures, for example, that if two exceptions "Supabase"
181                // and "Supabase Auth" are defined, the "Supabase Auth"
182                // exception will trigger first, preventing "Auth" from being
183                // matched as a false positive.
184                //
185                // Note that this is not a perfect solution, as the order of
186                // matched pattern lengths is not guaranteed to be the same as
187                // the order of regex pattern lengths. For example, the regex
188                // "a{35}" is shorter than "abcdefg", but will match a longer
189                // result. However, since we're unlikely to see regexes defined
190                // this way in exception files, we're just going to ignore this
191                // issue for now.
192                vec.sort_by_key(|b| std::cmp::Reverse(b.as_str().len()));
193                Some(vec)
194            }
195        } else {
196            None
197        }
198    }
199
200    #[cfg(test)]
201    pub(crate) fn with_serializable<T: Serialize>(key: &str, value: &T) -> Self {
202        Self::from_key_value(key, toml::Value::try_from(value).unwrap())
203    }
204
205    // TODO: global config should not keep carrying around the rule-level configs after the rules are set up, because the rules could mutate it
206    fn get_deserializable<T: for<'de> Deserialize<'de>>(&mut self, key: &str) -> Option<T> {
207        if let toml::Value::Table(ref mut table) = self.0 {
208            if let Some(value) = table.remove(key) {
209                if let Ok(item) = value.try_into() {
210                    return Some(item);
211                }
212            }
213        }
214        None
215    }
216}
217
/// Optional allow-list of rule names; `None` means no filtering.
pub(crate) type RuleFilter<'filter> = Option<&'filter [&'filter str]>;
219
220#[derive(Debug)]
221pub(crate) struct RuleRegistry<Phase> {
222    _phase: PhantomData<Phase>,
223    rules: Vec<Box<dyn Rule>>,
224    configured_levels: HashMap<String, LintLevel>,
225}
226
227impl RuleRegistry<PhaseSetup> {
228    pub fn new() -> Self {
229        Self {
230            _phase: PhantomData,
231            rules: get_all_rules(),
232            configured_levels: Default::default(),
233        }
234    }
235
236    pub fn save_configured_level(&mut self, rule_name: &str, level: LintLevel) {
237        self.configured_levels.insert(rule_name.to_string(), level);
238    }
239
240    pub fn setup(
241        mut self,
242        settings: &mut HashMap<String, RuleSettings>,
243    ) -> Result<RuleRegistry<PhaseReady>> {
244        for rule in &mut self.rules {
245            let rule_settings = settings.get_mut(rule.name());
246            rule.setup(rule_settings);
247        }
248
249        Ok(RuleRegistry {
250            _phase: PhantomData,
251            rules: self.rules,
252            configured_levels: self.configured_levels,
253        })
254    }
255}
256
257impl RuleRegistry<PhaseReady> {
258    pub fn run(&self, context: &Context) -> Result<Vec<LintError>> {
259        let mut errors = Vec::new();
260        self.check_node(context.parse_result.ast(), context, &mut errors);
261        Ok(errors)
262    }
263
264    fn check_node(&self, ast: &Node, context: &Context, errors: &mut Vec<LintError>) {
265        for rule in &self.rules {
266            if let Some(filter) = &context.check_only_rules {
267                if !filter.contains(&rule.name()) {
268                    continue;
269                }
270            }
271
272            let rule_level = rule.get_level(self.get_configured_level(rule.name()));
273            if let Some(rule_errors) = rule.check(ast, context, rule_level) {
274                debug!("Rule errors: {:#?}", rule_errors);
275                let filtered_errors: Vec<LintError> = rule_errors
276                    .into_iter()
277                    .filter(|err| {
278                        !context
279                            .disables
280                            .disabled_for_location(rule.name(), &err.location, context)
281                    })
282                    .collect();
283                errors.extend(filtered_errors);
284            }
285        }
286
287        if let Some(children) = ast.children() {
288            for child in children {
289                self.check_node(child, context, errors);
290            }
291        }
292    }
293}
294
295impl<State> RuleRegistry<State> {
296    pub fn is_valid_rule(&self, rule_name: &str) -> bool {
297        self.rules.iter().any(|rule| rule.name() == rule_name)
298    }
299
300    pub fn deactivate_rule(&mut self, rule_name: &str) {
301        self.rules.retain(|rule| rule.name() != rule_name);
302    }
303
304    pub fn get_configured_level(&self, rule_name: &str) -> Option<LintLevel> {
305        self.configured_levels.get(rule_name).cloned()
306    }
307
308    #[cfg(test)]
309    pub(crate) fn is_rule_active(&self, rule_name: &str) -> bool {
310        self.rules.iter().any(|rule| rule.name() == rule_name)
311    }
312
313    #[cfg(test)]
314    pub(crate) fn deactivate_all_but(&mut self, rule_name: &str) {
315        self.rules.retain(|rule| rule.name() == rule_name)
316    }
317}
318
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use crate::parser::parse;

    use super::*;
    use markdown::mdast::Node;
    use supa_mdx_macros::RuleName;

    /// Mock rule that reports nothing and counts how often `check` runs.
    #[derive(Clone, Default, Debug, RuleName)]
    struct MockRule {
        check_count: Arc<AtomicUsize>,
    }

    impl Rule for MockRule {
        fn default_level(&self) -> LintLevel {
            LintLevel::Error
        }

        fn check(
            &self,
            _ast: &Node,
            _context: &Context,
            _level: LintLevel,
        ) -> Option<Vec<LintError>> {
            self.check_count.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

    /// Second mock rule with a distinct type, so that rule filtering can be
    /// observed.
    #[derive(Clone, Default, Debug, RuleName)]
    struct MockRule2 {
        check_count: Arc<AtomicUsize>,
    }

    impl Rule for MockRule2 {
        fn default_level(&self) -> LintLevel {
            LintLevel::Error
        }

        fn check(
            &self,
            _ast: &Node,
            _context: &Context,
            _level: LintLevel,
        ) -> Option<Vec<LintError>> {
            self.check_count.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

    /// Builds a ready registry holding one `MockRule` and one `MockRule2`,
    /// returning it together with handles to both call counters.
    fn mock_registry() -> (RuleRegistry<PhaseReady>, Arc<AtomicUsize>, Arc<AtomicUsize>) {
        let first = MockRule::default();
        let second = MockRule2::default();
        let first_count = Arc::clone(&first.check_count);
        let second_count = Arc::clone(&second.check_count);

        let registry = RuleRegistry {
            _phase: PhantomData,
            rules: vec![Box::new(first), Box::new(second)],
            configured_levels: HashMap::new(),
        };

        (registry, first_count, second_count)
    }

    #[test]
    fn test_check_node_with_filter() {
        let (registry, first_count, second_count) = mock_registry();

        let parse_result = parse("text").unwrap();
        let context = Context::builder()
            .parse_result(&parse_result)
            .check_only_rules(&["MockRule"])
            .build()
            .unwrap();

        let mut errors = Vec::new();
        registry.check_node(parse_result.ast(), &context, &mut errors);

        // `check_node` recurses into child nodes, so the selected rule is
        // expected to run more than once; the filtered-out rule never runs.
        assert!(first_count.load(Ordering::Relaxed) > 1);
        assert_eq!(second_count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_check_node_without_filter() {
        let (registry, first_count, second_count) = mock_registry();

        let parse_result = parse("test").unwrap();
        let context = Context::builder()
            .parse_result(&parse_result)
            .build()
            .unwrap();

        let mut errors = Vec::new();
        registry.check_node(parse_result.ast(), &context, &mut errors);

        // With no filter, both rules run on every visited node.
        assert!(first_count.load(Ordering::Relaxed) > 1);
        assert!(second_count.load(Ordering::Relaxed) > 1);
    }
}