From 4309b55aa41cd84c4f2ec1099602816acef2a6c4 Mon Sep 17 00:00:00 2001
From: Chris W
Date: Mon, 30 Oct 2023 20:41:16 -0600
Subject: [PATCH] add /detect route for language detection

---
 src/config.rs | 53 ++++++++++++++++++++++++++++++++++-----------
 src/main.rs   | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 105 insertions(+), 10 deletions(-)

diff --git a/src/config.rs b/src/config.rs
index 1be3fdf..82cfcb4 100755
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,4 +1,5 @@
 use anyhow::Error;
+use std::collections::HashMap;
 use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
 use silicon::utils::{Background, ShadowAdder};
 use std::path::PathBuf;
@@ -11,6 +12,17 @@ use crate::rgba::{ImageRgba, Rgba};
 type FontList = Vec<(String, f32)>;
 type Lines = Vec<u32>;
 
+/// Unwrap a `Result`, or early-return the fallback expression from the
+/// enclosing function or closure on `Err`.
+macro_rules! unwrap_or_return {
+    ( $e:expr, $r:expr ) => {
+        match $e {
+            Ok(x) => x,
+            Err(_) => return $r,
+        }
+    };
+}
+
 #[derive(Debug, serde::Deserialize)]
 pub struct Config {
     /// Background image URL
@@ -132,15 +144,34 @@ impl Config {
                 ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
                     // Try using tensorflow to detect the language
                     let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap();
-                    self.predict_language_with_tensorflow(ps, input_data)
-                        .unwrap_or_else(|_| ps.find_syntax_by_token("log").unwrap())
+                    // Never unwrap() the prediction: it errors whenever no model is
+                    // loaded, and that must fall back to "log", not panic the request.
+                    let predictions = unwrap_or_return!(
+                        self.predict_language_with_tensorflow(ps, input_data),
+                        ps.find_syntax_by_token("log").unwrap()
+                    );
+
+                    // Pick the class with the highest score.
+                    let mut max_score = -std::f32::INFINITY;
+                    let mut max_language = "log";
+                    for (language, score) in &predictions {
+                        if *score > max_score {
+                            max_score = *score;
+                            max_language = language;
+                        }
+                    }
+
+                    ps.find_syntax_by_token(max_language)
+                        .unwrap_or_else(|| ps.find_syntax_by_token("log").unwrap())
                 })
             },
         };
         Ok(language)
     }
 
-    pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<&'a SyntaxReference, Error> {
+    /// Run the TensorFlow model over `input_data` and return a map from
+    /// language token to log2-scaled confidence score.
+    pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<HashMap<String, f32>, Error> {
         if self.tf_model_graph.is_none() || self.tf_model.is_none() {
             return Err(Error::msg("TensorFlow model not loaded"));
         }
@@ -165,13 +196,15 @@ impl Config {
 
         let classes: Tensor<String> = args.fetch(output_token_classes)?;
 
-        // Find the index of the highest score
-        let max_index = scores.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap().0;
-
+        // Collect every (class, score) pair instead of only the argmax so the
+        // caller can rank candidates; scores are log2-scaled for readability.
+        let mut result: HashMap<String, f32> = HashMap::new();
+        for (i, score) in scores.iter().enumerate() {
+            let class = classes[i].clone();
+            let log_score = score.log2();
+            result.insert(class, log_score);
+        }
 
-        let language = &classes[max_index];
-        let language = ps.find_syntax_by_token(language).unwrap();
-
-        Ok(language)
+        Ok(result)
     }
 }
diff --git a/src/main.rs b/src/main.rs
index 140d67d..7b0f930 100755
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,6 +7,7 @@ use anyhow::Error;
 use lazy_static::lazy_static;
 use silicon as si;
 use silicon::utils::ToRgba;
+use tensorflow::Tensor;
 use std::collections::HashSet;
 use std::io::Cursor;
 use std::num::ParseIntError;
@@ -87,6 +88,12 @@ async fn help() -> impl Responder {
         "GET /themes": "Return a list of available syntax themes.",
         "GET /languages": "Retuns a list of languages which can be parsed.",
         "GET /fonts": "Returns a list of available fonts.",
+        "GET /detect": {
+            "description": "Detect the language of the given code.",
+            "parameters": {
+                "code": "The code to detect the language of. Required."
+            }
+        },
         "GET /generate": {
             "description": "Generate an image from the given code.",
             "parameters": {
@@ -150,6 +157,61 @@ async fn fonts() -> impl Responder {
     HttpResponse::Ok().json(fonts)
 }
 
+/// GET /detect — run the TensorFlow model over `code` and return a JSON
+/// array of `{language, score}` candidates, best match first, with scores
+/// normalized to the 0..100 range.
+#[get("/detect")]
+async fn detect(info: web::Query<config::Config>) -> impl Responder {
+    let args = CliArgs::parse();
+    let ha = &*HIGHLIGHTING_ASSETS;
+
+    let (ps, _ts) = (&ha.syntax_set, &ha.theme_set);
+
+    let mut conf = config::Config::default();
+    conf.code = info.code.clone();
+    if conf.code.is_empty() {
+        return HttpResponse::BadRequest()
+            .append_header(("Content-Type", "application/json"))
+            .body(r#"{"error": "code parameter is required"}"#);
+    }
+
+    if let Some(model_dir) = args.tensorflow_model_dir.as_deref() {
+        conf.load_tensorflow_model(model_dir);
+    }
+
+    let input_data = Tensor::new(&[1]).with_values(&[conf.code.clone()]).unwrap();
+    let predictions = unwrap_or_return!(
+        conf.predict_language_with_tensorflow(ps, input_data),
+        HttpResponse::BadRequest()
+            .append_header(("Content-Type", "application/json"))
+            .body(r#"{"error": "Failed to detect language."}"#)
+    );
+
+    // Normalize scores into 0..100; guard the divisor so a single-class or
+    // all-equal result does not divide by zero (which would yield NaN and
+    // panic in the partial_cmp below).
+    let min_score = predictions.values().copied().fold(f32::INFINITY, f32::min);
+    let max_score = predictions.values().copied().fold(f32::NEG_INFINITY, f32::max);
+    let range = if max_score > min_score { max_score - min_score } else { 1.0 };
+
+    let mut normalized: Vec<_> = predictions
+        .iter()
+        .map(|(lang, score)| (lang, (score - min_score) / range * 100.0))
+        .collect();
+    normalized.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
+
+    // Top 5 candidates as a JSON array, best match first.
+    let response = normalized
+        .iter()
+        .take(5)
+        .map(|(language, score)| format!("{{\"language\": \"{}\", \"score\": {}}}", language, score))
+        .collect::<Vec<String>>()
+        .join(",");
+
+    HttpResponse::Ok()
+        .append_header(("Content-Type", "application/json"))
+        .body(format!("[{}]", response))
+}
+
 #[get("/generate")]
 async fn generate(info: web::Query<config::Config>) -> impl Responder {
     let args = CliArgs::parse();
@@ -290,6 +352,7 @@ async fn main() -> std::io::Result<()> {
             .service(themes)
             .service(languages)
             .service(fonts)
+            .service(detect)
             .service(generate)
     })
     .bind((host.clone(), port.parse::<u16>().unwrap()))?