1use anyhow::Result;
2use log::{debug, warn};
3use markdown::mdast::Node;
4use regex::Regex;
5use serde::Deserialize;
6use std::{collections::HashMap, fmt::Debug, marker::PhantomData};
7
8#[cfg(test)]
9use serde::Serialize;
10
11use crate::{
12 context::Context,
13 errors::{LintError, LintLevel},
14 PhaseReady, PhaseSetup,
15};
16
17mod rule001_heading_case;
18mod rule002_admonition_types;
19mod rule003_spelling;
20mod rule004_exclude_words;
21mod rule005_admonition_newlines;
22mod rule006_no_absolute_urls;
23
24pub use rule001_heading_case::Rule001HeadingCase;
25pub use rule002_admonition_types::Rule002AdmonitionTypes;
26pub use rule003_spelling::Rule003Spelling;
27pub use rule004_exclude_words::Rule004ExcludeWords;
28pub use rule005_admonition_newlines::Rule005AdmonitionNewlines;
29pub use rule006_no_absolute_urls::Rule006NoAbsoluteUrls;
30
31fn get_all_rules() -> Vec<Box<dyn Rule>> {
32 vec![
33 Box::new(Rule001HeadingCase::default()),
34 Box::new(Rule002AdmonitionTypes::default()),
35 Box::new(Rule003Spelling::default()),
36 Box::new(Rule004ExcludeWords::default()),
37 Box::new(Rule005AdmonitionNewlines),
38 Box::new(Rule006NoAbsoluteUrls::default()),
39 ]
40}
41
42pub(crate) trait Rule: Debug + RuleName {
43 fn default_level(&self) -> LintLevel;
44 fn setup(&mut self, _settings: Option<&mut RuleSettings>) {}
45 fn check(&self, ast: &Node, context: &Context, level: LintLevel) -> Option<Vec<LintError>>;
46}
47
/// Gives a rule its stable, static identifier.
///
/// The registry uses this name to look up settings and configured levels,
/// and to match rules against filters.
pub(crate) trait RuleName {
    /// Returns the rule's identifier.
    fn name(&self) -> &'static str;
}
51
52impl dyn Rule {
53 pub fn get_level(&self, configured_level: Option<LintLevel>) -> LintLevel {
54 configured_level.unwrap_or(self.default_level())
55 }
56}
57
58#[derive(Clone, Debug)]
59pub(crate) struct RuleSettings(toml::Value);
60
61#[derive(Default)]
62pub(crate) struct RegexSettings {
63 pub(crate) beginning: Option<RegexBeginning>,
64 pub(crate) ending: Option<RegexEnding>,
66}
67
/// How the start of a configured regex pattern is anchored.
///
/// Consumed by `RuleSettings::get_array_of_regexes` when it compiles
/// patterns from rule settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RegexBeginning {
    /// Anchor with `^`, unless the pattern already starts with `^`.
    VeryBeginning,
    /// Anchor with `(?:^|\s|\b)`, unless the pattern already starts with
    /// `\b`, `\s` or `^`.
    WordBoundary,
}
72
/// How the end of a configured regex pattern is anchored.
///
/// Consumed by `RuleSettings::get_array_of_regexes` when it compiles
/// patterns from rule settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RegexEnding {
    /// Require whitespace, a word boundary, end of input or one of the
    /// characters `.,!?'"-` after the pattern, unless the pattern already
    /// ends with `\b`, `\s` or `$`.
    WordBoundary,
}
76
77impl RuleSettings {
78 pub fn new(table: impl Into<toml::Table>) -> Self {
79 Self(toml::Value::Table(table.into()))
80 }
81
82 #[cfg(test)]
83 pub(crate) fn has_key(&self, key: &str) -> bool {
84 self.0
85 .as_table()
86 .map(|table| table.contains_key(key))
87 .unwrap_or(false)
88 }
89
90 #[cfg(test)]
91 pub(crate) fn from_key_value(key: &str, value: toml::Value) -> Self {
92 let mut table = toml::Table::new();
93 table.insert(key.to_string(), value);
94 Self::new(table)
95 }
96
97 #[cfg(test)]
98 pub(crate) fn with_array_of_strings(key: &str, values: Vec<&str>) -> Self {
99 let array = values
100 .into_iter()
101 .map(|s| toml::Value::String(s.to_string()))
102 .collect();
103 Self::from_key_value(key, toml::Value::Array(array))
104 }
105
106 fn get_array_of_strings(&self, key: &str) -> Option<Vec<String>> {
107 let table = &self.0;
108 if let Some(toml::Value::Array(array)) = table.get(key) {
109 let mut vec = Vec::new();
110 for value in array {
111 if let toml::Value::String(string) = value {
112 vec.push(string.to_lowercase());
113 }
114 }
115
116 if vec.is_empty() {
117 return None;
118 } else {
119 return Some(vec);
120 }
121 }
122
123 None
124 }
125
126 fn get_array_of_regexes(
127 &self,
128 key: &str,
129 settings: Option<&RegexSettings>,
130 ) -> Option<Vec<Regex>> {
131 let table = &self.0;
132 if let Some(toml::Value::Array(array)) = table.get(key) {
133 let mut vec = Vec::new();
134 for value in array {
135 if let toml::Value::String(pattern) = value {
136 let mut pattern = pattern.to_string();
137 if let Some(settings) = settings {
138 match settings.beginning {
139 Some(RegexBeginning::VeryBeginning) => {
140 if !pattern.starts_with('^') {
141 pattern = format!("^{}", pattern);
142 }
143 }
144 Some(RegexBeginning::WordBoundary) => {
145 if !pattern.starts_with("\\b")
146 && !pattern.starts_with("\\s")
147 && !pattern.starts_with("^")
148 {
149 pattern = format!("(?:^|\\s|\\b){}", pattern);
150 }
151 }
152 None => {}
153 }
154 #[allow(clippy::single_match)]
155 match settings.ending {
156 Some(RegexEnding::WordBoundary) => {
157 if !pattern.ends_with("\\b")
158 && !pattern.ends_with("\\s")
159 && !pattern.ends_with("$")
160 {
161 pattern = format!(r#"{}(?:\s|\b|$|[.,!?'"-])"#, pattern);
162 }
163 }
164 None => {}
165 }
166 }
167
168 if let Ok(regex) = Regex::new(&pattern) {
169 vec.push(regex);
170 } else {
171 warn!("Encountered invalid regex pattern in rule settings: {pattern}")
172 }
173 }
174 }
175 if vec.is_empty() {
176 None
177 } else {
178 vec.sort_by_key(|b| std::cmp::Reverse(b.as_str().len()));
193 Some(vec)
194 }
195 } else {
196 None
197 }
198 }
199
200 #[cfg(test)]
201 pub(crate) fn with_serializable<T: Serialize>(key: &str, value: &T) -> Self {
202 Self::from_key_value(key, toml::Value::try_from(value).unwrap())
203 }
204
205 fn get_deserializable<T: for<'de> Deserialize<'de>>(&mut self, key: &str) -> Option<T> {
207 if let toml::Value::Table(ref mut table) = self.0 {
208 if let Some(value) = table.remove(key) {
209 if let Ok(item) = value.try_into() {
210 return Some(item);
211 }
212 }
213 }
214 None
215 }
216}
217
/// An optional list of rule names.
///
/// `None` carries no list at all; `Some(names)` borrows a slice of names.
/// NOTE(review): presumably the type behind `Context::check_only_rules`,
/// where `None` means "run every rule" — confirm against `Context`.
pub(crate) type RuleFilter<'filter> = Option<&'filter [&'filter str]>;
219
220#[derive(Debug)]
221pub(crate) struct RuleRegistry<Phase> {
222 _phase: PhantomData<Phase>,
223 rules: Vec<Box<dyn Rule>>,
224 configured_levels: HashMap<String, LintLevel>,
225}
226
227impl RuleRegistry<PhaseSetup> {
228 pub fn new() -> Self {
229 Self {
230 _phase: PhantomData,
231 rules: get_all_rules(),
232 configured_levels: Default::default(),
233 }
234 }
235
236 pub fn save_configured_level(&mut self, rule_name: &str, level: LintLevel) {
237 self.configured_levels.insert(rule_name.to_string(), level);
238 }
239
240 pub fn setup(
241 mut self,
242 settings: &mut HashMap<String, RuleSettings>,
243 ) -> Result<RuleRegistry<PhaseReady>> {
244 for rule in &mut self.rules {
245 let rule_settings = settings.get_mut(rule.name());
246 rule.setup(rule_settings);
247 }
248
249 Ok(RuleRegistry {
250 _phase: PhantomData,
251 rules: self.rules,
252 configured_levels: self.configured_levels,
253 })
254 }
255}
256
257impl RuleRegistry<PhaseReady> {
258 pub fn run(&self, context: &Context) -> Result<Vec<LintError>> {
259 let mut errors = Vec::new();
260 self.check_node(context.parse_result.ast(), context, &mut errors);
261 Ok(errors)
262 }
263
264 fn check_node(&self, ast: &Node, context: &Context, errors: &mut Vec<LintError>) {
265 for rule in &self.rules {
266 if let Some(filter) = &context.check_only_rules {
267 if !filter.contains(&rule.name()) {
268 continue;
269 }
270 }
271
272 let rule_level = rule.get_level(self.get_configured_level(rule.name()));
273 if let Some(rule_errors) = rule.check(ast, context, rule_level) {
274 debug!("Rule errors: {:#?}", rule_errors);
275 let filtered_errors: Vec<LintError> = rule_errors
276 .into_iter()
277 .filter(|err| {
278 !context
279 .disables
280 .disabled_for_location(rule.name(), &err.location, context)
281 })
282 .collect();
283 errors.extend(filtered_errors);
284 }
285 }
286
287 if let Some(children) = ast.children() {
288 for child in children {
289 self.check_node(child, context, errors);
290 }
291 }
292 }
293}
294
295impl<State> RuleRegistry<State> {
296 pub fn is_valid_rule(&self, rule_name: &str) -> bool {
297 self.rules.iter().any(|rule| rule.name() == rule_name)
298 }
299
300 pub fn deactivate_rule(&mut self, rule_name: &str) {
301 self.rules.retain(|rule| rule.name() != rule_name);
302 }
303
304 pub fn get_configured_level(&self, rule_name: &str) -> Option<LintLevel> {
305 self.configured_levels.get(rule_name).cloned()
306 }
307
308 #[cfg(test)]
309 pub(crate) fn is_rule_active(&self, rule_name: &str) -> bool {
310 self.rules.iter().any(|rule| rule.name() == rule_name)
311 }
312
313 #[cfg(test)]
314 pub(crate) fn deactivate_all_but(&mut self, rule_name: &str) {
315 self.rules.retain(|rule| rule.name() == rule_name)
316 }
317}
318
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use crate::parser::parse;

    use super::*;
    use markdown::mdast::Node;
    use supa_mdx_macros::RuleName;

    /// Rule that reports nothing and only counts how often `check` is called.
    #[derive(Clone, Default, Debug, RuleName)]
    struct MockRule {
        check_count: Arc<AtomicUsize>,
    }

    impl Rule for MockRule {
        fn default_level(&self) -> LintLevel {
            LintLevel::Error
        }

        fn check(
            &self,
            _ast: &Node,
            _context: &Context,
            _level: LintLevel,
        ) -> Option<Vec<LintError>> {
            self.check_count.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

    /// Second counting rule, distinct from `MockRule` only by name.
    #[derive(Clone, Default, Debug, RuleName)]
    struct MockRule2 {
        check_count: Arc<AtomicUsize>,
    }

    impl Rule for MockRule2 {
        fn default_level(&self) -> LintLevel {
            LintLevel::Error
        }

        fn check(
            &self,
            _ast: &Node,
            _context: &Context,
            _level: LintLevel,
        ) -> Option<Vec<LintError>> {
            self.check_count.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

    /// Builds a ready-phase registry containing exactly `rules` and no
    /// configured levels.
    fn ready_registry(rules: Vec<Box<dyn Rule>>) -> RuleRegistry<PhaseReady> {
        RuleRegistry {
            _phase: PhantomData,
            rules,
            configured_levels: HashMap::new(),
        }
    }

    /// With a filter naming only `MockRule`, the other rule must never run.
    #[test]
    fn test_check_node_with_filter() {
        let first = MockRule::default();
        let second = MockRule2::default();
        let first_calls = Arc::clone(&first.check_count);
        let second_calls = Arc::clone(&second.check_count);

        let registry = ready_registry(vec![Box::new(first), Box::new(second)]);

        let parse_result = parse("text").unwrap();
        let context = Context::builder()
            .parse_result(&parse_result)
            .check_only_rules(&["MockRule"])
            .build()
            .unwrap();

        let mut errors = Vec::new();
        registry.check_node(parse_result.ast(), &context, &mut errors);

        assert!(first_calls.load(Ordering::Relaxed) > 1);
        assert_eq!(second_calls.load(Ordering::Relaxed), 0);
    }

    /// Without a filter, both rules run on more than one node.
    #[test]
    fn test_check_node_without_filter() {
        let first = MockRule::default();
        let second = MockRule2::default();
        let first_calls = Arc::clone(&first.check_count);
        let second_calls = Arc::clone(&second.check_count);

        let registry = ready_registry(vec![Box::new(first), Box::new(second)]);

        let parse_result = parse("test").unwrap();
        let context = Context::builder()
            .parse_result(&parse_result)
            .build()
            .unwrap();

        let mut errors = Vec::new();
        registry.check_node(parse_result.ast(), &context, &mut errors);

        assert!(first_calls.load(Ordering::Relaxed) > 1);
        assert!(second_calls.load(Ordering::Relaxed) > 1);
    }
}