diff --git a/crates/cli/src/commands/apply_pattern.rs b/crates/cli/src/commands/apply_pattern.rs index 4cd7bf6b4..f42fb20a1 100644 --- a/crates/cli/src/commands/apply_pattern.rs +++ b/crates/cli/src/commands/apply_pattern.rs @@ -237,7 +237,7 @@ pub(crate) async fn run_apply_pattern( "A path must have an extension to determine the language for stdin" ))?; if let Some(ext) = ext.to_str() { - PatternLanguage::from_extension(ext) + PatternLanguage::from_string_or_alias(ext, None) } else { default_lang } diff --git a/crates/core/src/pattern_compiler/compiler.rs b/crates/core/src/pattern_compiler/compiler.rs index a99d9426c..75a44d0fd 100644 --- a/crates/core/src/pattern_compiler/compiler.rs +++ b/crates/core/src/pattern_compiler/compiler.rs @@ -673,9 +673,23 @@ mod tests { let pattern_tsx = "language js(jsx)"; let pattern_default = "language js"; let pattern_default_fall_through = "language js(block)"; + let pattern_python_alias = "language py"; + let pattern_csharp_alias = "language cs"; + let pattern_rust_alias = "language rs"; + let pattern_ruby_alias = "language rb"; + let pattern_solidity_alias = "language sol"; + let pattern_hcl_alias = "language tf"; + let pattern_yaml_alias = "language yml"; let js: TargetLanguage = PatternLanguage::JavaScript.try_into().unwrap(); let ts: TargetLanguage = PatternLanguage::TypeScript.try_into().unwrap(); let tsx: TargetLanguage = PatternLanguage::Tsx.try_into().unwrap(); + let python: TargetLanguage = PatternLanguage::Python.try_into().unwrap(); + let csharp: TargetLanguage = PatternLanguage::CSharp.try_into().unwrap(); + let rust: TargetLanguage = PatternLanguage::Rust.try_into().unwrap(); + let ruby: TargetLanguage = PatternLanguage::Ruby.try_into().unwrap(); + let solidity: TargetLanguage = PatternLanguage::Solidity.try_into().unwrap(); + let hcl: TargetLanguage = PatternLanguage::Hcl.try_into().unwrap(); + let yaml: TargetLanguage = PatternLanguage::Yaml.try_into().unwrap(); assert_eq!( TargetLanguage::get_language(pattern_javascript) .unwrap() @@ -706,5 +720,47 @@ mod tests { .language_name(), tsx.language_name() ); + assert_eq!( + TargetLanguage::get_language(pattern_python_alias) + .unwrap() + .language_name(), + python.language_name() + ); + assert_eq!( + TargetLanguage::get_language(pattern_csharp_alias) + .unwrap() + .language_name(), + csharp.language_name() + ); + assert_eq!( + TargetLanguage::get_language(pattern_rust_alias) + .unwrap() + .language_name(), + rust.language_name() + ); + assert_eq!( + TargetLanguage::get_language(pattern_ruby_alias) + .unwrap() + .language_name(), + ruby.language_name() + ); + assert_eq!( + TargetLanguage::get_language(pattern_solidity_alias) + .unwrap() + .language_name(), + solidity.language_name() + ); + assert_eq!( + TargetLanguage::get_language(pattern_hcl_alias) + .unwrap() + .language_name(), + hcl.language_name() + ); + assert_eq!( + TargetLanguage::get_language(pattern_yaml_alias) + .unwrap() + .language_name(), + yaml.language_name() + ); } } diff --git a/crates/language/src/target_language.rs b/crates/language/src/target_language.rs index 6b6040dcf..45fa182cf 100644 --- a/crates/language/src/target_language.rs +++ b/crates/language/src/target_language.rs @@ -129,6 +129,48 @@ impl PatternLanguage { Self::from_string(lang, flavor.as_deref()) } +impl PatternLanguage { + pub fn from_string(name: &str, flavor: Option<&str>) -> Option { + match name { + "py" | "python" => Some(Self::Python), + "js" | "javascript" => match flavor { + Some("jsx") => Some(Self::Tsx), + Some("flow") => Some(Self::Tsx), + Some("FlowComments") => Some(Self::Tsx), + Some("typescript") => Some(Self::TypeScript), + Some("js_do_not_use") => Some(Self::JavaScript), + _ => Some(Self::Tsx), + }, + "ts" | "typescript" => Some(Self::Typescript), + "html" => Some(Self::Html), + "css" => Some(Self::Css), + "json" => Some(Self::Json), + "java" => Some(Self::Java), + "cs" | "csharp" => Some(Self::CSharp), + "md" | "markdown" => match flavor { + Some("block") => Some(Self::MarkdownBlock), + Some("inline") => Some(Self::MarkdownInline), + _ => Some(Self::MarkdownInline), + }, + "go" => Some(Self::Go), + "rs" | "rust" => Some(Self::Rust), + "rb" | "ruby" => Some(Self::Ruby), + "sol" | "solidity" => Some(Self::Solidity), + "hcl" | "tf" | "terraform" => Some(Self::Hcl), + "yml" | "yaml" => Some(Self::Yaml), + "sql" => Some(Self::Sql), + "vue" => Some(Self::Vue), + "toml" => Some(Self::Toml), + "php" => match flavor { + Some("html") => Some(Self::Php), + Some("only") => Some(self::PhpOnly), + _ => Some(Self::Php), + }, + "universal" => Some(Self::Universal), + _ => None, + } + } + #[cfg(not(feature = "builtin-parser"))] pub fn get_language_with_parser(_parser: &mut MarzanoGritParser, _body: &str) -> Option { unimplemented!("grit_parser is unavailable when feature flag [builtin-parser] is off.") @@ -147,7 +189,8 @@ impl PatternLanguage { pub fn from_string(name: &str, flavor: Option<&str>) -> Option { match name { - "js" => match flavor { + "py" | "python" => Some(Self::Python), + "js" | "javascript" => match flavor { Some("jsx") => Some(Self::Tsx), Some("flow") => Some(Self::Tsx), Some("flowComments") => Some(Self::Tsx), @@ -155,35 +198,34 @@ impl PatternLanguage { Some("js_do_not_use") => Some(Self::JavaScript), _ => Some(Self::Tsx), }, - "html" => Some(Self::Html), - "css" => Some(Self::Css), - "json" => Some(Self::Json), - "java" => Some(Self::Java), - "csharp" => Some(Self::CSharp), - "markdown" => match flavor { - Some("block") => Some(Self::MarkdownBlock), - Some("inline") => Some(Self::MarkdownInline), - _ => Some(Self::MarkdownInline), - }, - "ipynb" => Some(Self::Python), - "python" => Some(Self::Python), - "go" => Some(Self::Go), - "rust" => Some(Self::Rust), - "ruby" => Some(Self::Ruby), - "sol" | "solidity" => Some(Self::Solidity), - "hcl" => Some(Self::Hcl), - "yaml" => Some(Self::Yaml), - "sql" => Some(Self::Sql), - "vue" => Some(Self::Vue), - "toml" => Some(Self::Toml), - "php" => match flavor { - Some("html") => Some(Self::Php), - Some("only") => Some(Self::PhpOnly), - _ => Some(Self::Php), - }, - "universal" => Some(Self::Universal), - _ => None, - } + "ts" | "typescript" => Some(Self::TypeScript), + "html" => Some(Self::Html), + "css" => Some(Self::Css), + "json" => Some(Self::Json), + "java" => Some(Self::Java), + "cs" | "csharp" => Some(Self::CSharp), + "md" | "markdown" => match flavor { + Some("block") => Some(Self::MarkdownBlock), + Some("inline") => Some(Self::MarkdownInline), + _ => Some(Self::MarkdownInline), + }, + "go" => Some(Self::Go), + "rs" | "rust" => Some(Self::Rust), + "rb" | "ruby" => Some(Self::Ruby), + "sol" | "solidity" => Some(Self::Solidity), + "hcl" | "tf" | "terraform" => Some(Self::Hcl), + "yml" | "yaml" => Some(Self::Yaml), + "sql" => Some(Self::Sql), + "vue" => Some(Self::Vue), + "toml" => Some(Self::Toml), + "php" => match flavor { + Some("html") => Some(Self::Php), + Some("only") => Some(Self::PhpOnly), + _ => Some(Self::Php), + }, + "universal" => Some(Self::Universal), + _ => None, + } } fn get_file_extensions(&self) -> &'static [&'static str] { @@ -245,6 +287,7 @@ impl PatternLanguage { } pub fn from_extension(extension: &str) -> Option { + Self::from_string(extension, None) match extension { "js" | "jsx" | "cjs" | "mjs" => Some(Self::Tsx), "ts" | "tsx" | "cts" | "mts" => Some(Self::Tsx), @@ -822,4 +865,45 @@ mod tests { let other_comment = lang.extract_single_line_comment(other_text).unwrap(); assert_eq!(other_comment, "this is a comment"); } + + #[test] + fn test_from_string_or_alias() { + assert_eq!(PatternLanguage::from_string_or_alias("py", None), Some(PatternLanguage::Python)); + assert_eq!(PatternLanguage::from_string_or_alias("python", None), Some(PatternLanguage::Python)); + assert_eq!(PatternLanguage::from_string_or_alias("js", None), Some(PatternLanguage::Tsx)); + assert_eq!(PatternLanguage::from_string_or_alias("javascript", None), Some(PatternLanguage::Tsx)); + assert_eq!(PatternLanguage::from_string_or_alias("ts", None), Some(PatternLanguage::TypeScript)); + assert_eq!(PatternLanguage::from_string_or_alias("typescript", None), Some(PatternLanguage::TypeScript)); + assert_eq!(PatternLanguage::from_string_or_alias("cs", None), Some(PatternLanguage::CSharp)); + assert_eq!(PatternLanguage::from_string_or_alias("csharp", None), Some(PatternLanguage::CSharp)); + assert_eq!(PatternLanguage::from_string_or_alias("md", None), Some(PatternLanguage::MarkdownInline)); + assert_eq!(PatternLanguage::from_string_or_alias("markdown", None), Some(PatternLanguage::MarkdownInline)); + assert_eq!(PatternLanguage::from_string_or_alias("rs", None), Some(PatternLanguage::Rust)); + assert_eq!(PatternLanguage::from_string_or_alias("rust", None), Some(PatternLanguage::Rust)); + assert_eq!(PatternLanguage::from_string_or_alias("rb", None), Some(PatternLanguage::Ruby)); + assert_eq!(PatternLanguage::from_string_or_alias("ruby", None), Some(PatternLanguage::Ruby)); + assert_eq!(PatternLanguage::from_string_or_alias("sol", None), Some(PatternLanguage::Solidity)); + assert_eq!(PatternLanguage::from_string_or_alias("solidity", None), Some(PatternLanguage::Solidity)); + assert_eq!(PatternLanguage::from_string_or_alias("tf", None), Some(PatternLanguage::Hcl)); + assert_eq!(PatternLanguage::from_string_or_alias("hcl", None), Some(PatternLanguage::Hcl)); + assert_eq!(PatternLanguage::from_string_or_alias("terraform", None), Some(PatternLanguage::Hcl)); + assert_eq!(PatternLanguage::from_string_or_alias("yml", None), Some(PatternLanguage::Yaml)); + assert_eq!(PatternLanguage::from_string_or_alias("yaml", None), Some(PatternLanguage::Yaml)); + assert_eq!(PatternLanguage::from_string_or_alias("unknown", None), None); + } + + #[test] + fn test_from_extension() { + assert_eq!(PatternLanguage::from_extension("py"), Some(PatternLanguage::Python)); + assert_eq!(PatternLanguage::from_extension("js"), Some(PatternLanguage::Tsx)); + assert_eq!(PatternLanguage::from_extension("ts"), Some(PatternLanguage::TypeScript)); + assert_eq!(PatternLanguage::from_extension("cs"), Some(PatternLanguage::CSharp)); + assert_eq!(PatternLanguage::from_extension("md"), Some(PatternLanguage::MarkdownInline)); + assert_eq!(PatternLanguage::from_extension("rs"), Some(PatternLanguage::Rust)); + assert_eq!(PatternLanguage::from_extension("rb"), Some(PatternLanguage::Ruby)); + assert_eq!(PatternLanguage::from_extension("sol"), Some(PatternLanguage::Solidity)); + assert_eq!(PatternLanguage::from_extension("tf"), Some(PatternLanguage::Hcl)); + assert_eq!(PatternLanguage::from_extension("yml"), Some(PatternLanguage::Yaml)); + assert_eq!(PatternLanguage::from_extension("unknown"), None); + } }